[Mlir-commits] [llvm] [mlir] [NFC][mlir][mesh, shard] Fixing misnomers in mesh dialect, renaming 'mesh' dialect to 'shard' (PR #150177)
Frank Schlimbach
llvmlistbot at llvm.org
Thu Jul 24 02:14:28 PDT 2025
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/150177
From b84d862a0adf684d8debeaa2a9bdd9d4b13a1e06 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 21 Jul 2025 15:00:41 +0200
Subject: [PATCH 1/2] Fixing misnomers in mesh dialect
- dialect name mesh -> shard
- (device) mesh -> (device) grid
- spmdize -> partition
---
mlir/docs/Dialects/{Mesh.md => Shard.md} | 34 +-
mlir/docs/Passes.md | 4 +-
mlir/include/mlir/Conversion/Passes.h | 2 +-
mlir/include/mlir/Conversion/Passes.td | 8 +-
.../MeshToMPI.h => ShardToMPI/ShardToMPI.h} | 10 +-
mlir/include/mlir/Dialect/CMakeLists.txt | 2 +-
...rdingExtensions.h => ShardingExtensions.h} | 2 +-
...nterfaceImpl.h => ShardingInterfaceImpl.h} | 10 +-
.../mlir/Dialect/Mesh/IR/CMakeLists.txt | 25 -
.../Dialect/Mesh/Transforms/CMakeLists.txt | 6 -
.../Dialect/{Mesh => Shard}/CMakeLists.txt | 0
.../mlir/Dialect/Shard/IR/CMakeLists.txt | 25 +
.../IR/MeshBase.td => Shard/IR/ShardBase.td} | 46 +-
.../MeshDialect.h => Shard/IR/ShardDialect.h} | 10 +-
.../IR/MeshOps.h => Shard/IR/ShardOps.h} | 146 +++--
.../IR/MeshOps.td => Shard/IR/ShardOps.td} | 420 +++++++-------
.../{Mesh => Shard}/Interfaces/CMakeLists.txt | 0
.../Interfaces/ShardingInterface.h | 58 +-
.../Interfaces/ShardingInterface.td | 48 +-
.../Interfaces/ShardingInterfaceImpl.h | 79 ++-
.../Dialect/Shard/Transforms/CMakeLists.txt | 6 +
.../Transforms/Partition.h} | 26 +-
.../{Mesh => Shard}/Transforms/Passes.h | 16 +-
.../{Mesh => Shard}/Transforms/Passes.td | 48 +-
.../Transforms/ReshardingPartitionDoc.md} | 144 ++---
.../Transforms/Simplifications.h | 14 +-
.../{Mesh => Shard}/Transforms/Transforms.h | 28 +-
...rdingExtensions.h => ShardingExtensions.h} | 2 +-
mlir/include/mlir/InitAllDialects.h | 4 +-
mlir/include/mlir/InitAllPasses.h | 4 +-
mlir/lib/Conversion/CMakeLists.txt | 2 +-
.../{MeshToMPI => ShardToMPI}/CMakeLists.txt | 8 +-
.../ShardToMPI.cpp} | 118 ++--
.../Dialect/Arith/Transforms/CMakeLists.txt | 2 +-
.../Transforms/ShardingInterfaceImpl.cpp | 36 +-
mlir/lib/Dialect/CMakeLists.txt | 2 +-
.../Dialect/Func/Extensions/AllExtensions.cpp | 2 +-
.../Dialect/Func/Extensions/CMakeLists.txt | 8 +-
...gExtensions.cpp => ShardingExtensions.cpp} | 8 +-
mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp | 6 +-
.../Linalg/Transforms/AllInterfaces.cpp | 4 +-
.../Dialect/Linalg/Transforms/CMakeLists.txt | 4 +-
...faceImpl.cpp => ShardingInterfaceImpl.cpp} | 167 +++---
.../Dialect/{Mesh => Shard}/CMakeLists.txt | 0
.../Dialect/{Mesh => Shard}/IR/CMakeLists.txt | 8 +-
.../IR/MeshOps.cpp => Shard/IR/ShardOps.cpp} | 526 +++++++++---------
.../{Mesh => Shard}/Interfaces/CMakeLists.txt | 4 +-
.../Interfaces/ShardingInterface.cpp | 270 +++++----
.../{Mesh => Shard}/Transforms/CMakeLists.txt | 10 +-
.../Transforms/Partition.cpp} | 393 +++++++------
.../Transforms/ShardingPropagation.cpp | 79 ++-
.../Transforms/Simplifications.cpp | 64 +--
.../{Mesh => Shard}/Transforms/Transforms.cpp | 82 +--
.../Transforms/TransformsDetail.h | 10 +-
.../Tensor/Extensions/AllExtensions.cpp | 2 +-
.../Dialect/Tensor/Extensions/CMakeLists.txt | 8 +-
...gExtensions.cpp => ShardingExtensions.cpp} | 38 +-
mlir/lib/Dialect/Tosa/CMakeLists.txt | 2 +-
.../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 26 +-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 4 +-
.../convert-shard-to-mpi.mlir} | 66 +--
.../convert-shardshape-to-mpi.mlir | 40 +-
...mesh-spmdize.mlir => shard-partition.mlir} | 12 +-
.../Dialect/Arith/sharding-propagation.mlir | 50 +-
.../Linalg/mesh-sharding-propagation.mlir | 42 --
...-spmdization.mlir => shard-partition.mlir} | 118 ++--
.../Dialect/Linalg/sharding-propagation.mlir | 42 ++
mlir/test/Dialect/Mesh/canonicalization.mlir | 248 ---------
mlir/test/Dialect/Mesh/folding.mlir | 22 -
mlir/test/Dialect/Mesh/inlining.mlir | 15 -
.../Mesh/process-multi-index-op-lowering.mlir | 23 -
.../Dialect/Mesh/resharding-spmdization.mlir | 168 ------
.../Dialect/Mesh/sharding-propagation.mlir | 301 ----------
mlir/test/Dialect/Mesh/spmdization.mlir | 317 -----------
.../all-scatter-op-lowering.mlir | 40 +-
.../backward-sharding-propagation.mlir | 10 +-
mlir/test/Dialect/Shard/canonicalization.mlir | 248 +++++++++
mlir/test/Dialect/Shard/folding.mlir | 22 +
...forward-backward-sharding-propagation.mlir | 14 +-
.../forward-sharding-propagation.mlir | 34 +-
mlir/test/Dialect/Shard/inlining.mlir | 15 +
.../test/Dialect/{Mesh => Shard}/invalid.mlir | 442 +++++++--------
mlir/test/Dialect/{Mesh => Shard}/ops.mlir | 350 ++++++------
mlir/test/Dialect/Shard/partition.mlir | 317 +++++++++++
.../process-multi-index-op-lowering.mlir | 23 +
.../Dialect/Shard/resharding-partition.mlir | 168 ++++++
.../sharding-propagation-failed.mlir | 0
.../Dialect/Shard/sharding-propagation.mlir | 301 ++++++++++
.../{Mesh => Shard}/simplifications.mlir | 78 +--
.../test/Dialect/Tensor/mesh-spmdization.mlir | 52 --
mlir/test/Dialect/Tensor/shard-partition.mlir | 52 ++
mlir/test/lib/Dialect/CMakeLists.txt | 2 +-
.../Dialect/{Mesh => Shard}/CMakeLists.txt | 10 +-
.../{Mesh => Shard}/TestOpLowering.cpp | 18 +-
.../TestReshardingPartition.cpp} | 40 +-
.../{Mesh => Shard}/TestSimplifications.cpp | 24 +-
mlir/tools/mlir-opt/CMakeLists.txt | 2 +-
mlir/tools/mlir-opt/mlir-opt.cpp | 8 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 146 ++---
.../mlir/test/BUILD.bazel | 8 +-
100 files changed, 3493 insertions(+), 3515 deletions(-)
rename mlir/docs/Dialects/{Mesh.md => Shard.md} (73%)
rename mlir/include/mlir/Conversion/{MeshToMPI/MeshToMPI.h => ShardToMPI/ShardToMPI.h} (64%)
rename mlir/include/mlir/Dialect/Func/Extensions/{MeshShardingExtensions.h => ShardingExtensions.h} (88%)
rename mlir/include/mlir/Dialect/Linalg/Transforms/{MeshShardingInterfaceImpl.h => ShardingInterfaceImpl.h} (54%)
delete mode 100644 mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
delete mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
rename mlir/include/mlir/Dialect/{Mesh => Shard}/CMakeLists.txt (100%)
create mode 100644 mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
rename mlir/include/mlir/Dialect/{Mesh/IR/MeshBase.td => Shard/IR/ShardBase.td} (64%)
rename mlir/include/mlir/Dialect/{Mesh/IR/MeshDialect.h => Shard/IR/ShardDialect.h} (57%)
rename mlir/include/mlir/Dialect/{Mesh/IR/MeshOps.h => Shard/IR/ShardOps.h} (52%)
rename mlir/include/mlir/Dialect/{Mesh/IR/MeshOps.td => Shard/IR/ShardOps.td} (70%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/CMakeLists.txt (100%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/ShardingInterface.h (52%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/ShardingInterface.td (80%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Interfaces/ShardingInterfaceImpl.h (58%)
create mode 100644 mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
rename mlir/include/mlir/Dialect/{Mesh/Transforms/Spmdization.h => Shard/Transforms/Partition.h} (61%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Passes.h (75%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Passes.td (65%)
rename mlir/include/mlir/Dialect/{Mesh/Transforms/ReshardingSpmdizationDoc.md => Shard/Transforms/ReshardingPartitionDoc.md} (87%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Simplifications.h (93%)
rename mlir/include/mlir/Dialect/{Mesh => Shard}/Transforms/Transforms.h (65%)
rename mlir/include/mlir/Dialect/Tensor/Extensions/{MeshShardingExtensions.h => ShardingExtensions.h} (88%)
rename mlir/lib/Conversion/{MeshToMPI => ShardToMPI}/CMakeLists.txt (65%)
rename mlir/lib/Conversion/{MeshToMPI/MeshToMPI.cpp => ShardToMPI/ShardToMPI.cpp} (92%)
rename mlir/lib/Dialect/Func/Extensions/{MeshShardingExtensions.cpp => ShardingExtensions.cpp} (68%)
rename mlir/lib/Dialect/Linalg/Transforms/{MeshShardingInterfaceImpl.cpp => ShardingInterfaceImpl.cpp} (66%)
rename mlir/lib/Dialect/{Mesh => Shard}/CMakeLists.txt (100%)
rename mlir/lib/Dialect/{Mesh => Shard}/IR/CMakeLists.txt (59%)
rename mlir/lib/Dialect/{Mesh/IR/MeshOps.cpp => Shard/IR/ShardOps.cpp} (76%)
rename mlir/lib/Dialect/{Mesh => Shard}/Interfaces/CMakeLists.txt (76%)
rename mlir/lib/Dialect/{Mesh => Shard}/Interfaces/ShardingInterface.cpp (70%)
rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/CMakeLists.txt (73%)
rename mlir/lib/Dialect/{Mesh/Transforms/Spmdization.cpp => Shard/Transforms/Partition.cpp} (66%)
rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/ShardingPropagation.cpp (85%)
rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/Simplifications.cpp (66%)
rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/Transforms.cpp (78%)
rename mlir/lib/Dialect/{Mesh => Shard}/Transforms/TransformsDetail.h (82%)
rename mlir/lib/Dialect/Tensor/Extensions/{MeshShardingExtensions.cpp => ShardingExtensions.cpp} (74%)
rename mlir/test/Conversion/{MeshToMPI/convert-mesh-to-mpi.mlir => ShardToMPI/convert-shard-to-mpi.mlir} (90%)
rename mlir/test/Conversion/{MeshToMPI => ShardToMPI}/convert-shardshape-to-mpi.mlir (62%)
rename mlir/test/Dialect/Arith/{mesh-spmdize.mlir => shard-partition.mlir} (50%)
delete mode 100644 mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
rename mlir/test/Dialect/Linalg/{mesh-spmdization.mlir => shard-partition.mlir} (50%)
create mode 100644 mlir/test/Dialect/Linalg/sharding-propagation.mlir
delete mode 100644 mlir/test/Dialect/Mesh/canonicalization.mlir
delete mode 100644 mlir/test/Dialect/Mesh/folding.mlir
delete mode 100644 mlir/test/Dialect/Mesh/inlining.mlir
delete mode 100644 mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
delete mode 100644 mlir/test/Dialect/Mesh/resharding-spmdization.mlir
delete mode 100644 mlir/test/Dialect/Mesh/sharding-propagation.mlir
delete mode 100644 mlir/test/Dialect/Mesh/spmdization.mlir
rename mlir/test/Dialect/{Mesh => Shard}/all-scatter-op-lowering.mlir (72%)
rename mlir/test/Dialect/{Mesh => Shard}/backward-sharding-propagation.mlir (76%)
create mode 100644 mlir/test/Dialect/Shard/canonicalization.mlir
create mode 100644 mlir/test/Dialect/Shard/folding.mlir
rename mlir/test/Dialect/{Mesh => Shard}/forward-backward-sharding-propagation.mlir (63%)
rename mlir/test/Dialect/{Mesh => Shard}/forward-sharding-propagation.mlir (53%)
create mode 100644 mlir/test/Dialect/Shard/inlining.mlir
rename mlir/test/Dialect/{Mesh => Shard}/invalid.mlir (57%)
rename mlir/test/Dialect/{Mesh => Shard}/ops.mlir (55%)
create mode 100644 mlir/test/Dialect/Shard/partition.mlir
create mode 100644 mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
create mode 100644 mlir/test/Dialect/Shard/resharding-partition.mlir
rename mlir/test/Dialect/{Mesh => Shard}/sharding-propagation-failed.mlir (100%)
create mode 100644 mlir/test/Dialect/Shard/sharding-propagation.mlir
rename mlir/test/Dialect/{Mesh => Shard}/simplifications.mlir (69%)
delete mode 100644 mlir/test/Dialect/Tensor/mesh-spmdization.mlir
create mode 100644 mlir/test/Dialect/Tensor/shard-partition.mlir
rename mlir/test/lib/Dialect/{Mesh => Shard}/CMakeLists.txt (51%)
rename mlir/test/lib/Dialect/{Mesh => Shard}/TestOpLowering.cpp (80%)
rename mlir/test/lib/Dialect/{Mesh/TestReshardingSpmdization.cpp => Shard/TestReshardingPartition.cpp} (75%)
rename mlir/test/lib/Dialect/{Mesh => Shard}/TestSimplifications.cpp (60%)
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Shard.md
similarity index 73%
rename from mlir/docs/Dialects/Mesh.md
rename to mlir/docs/Dialects/Shard.md
index 5eb6569c7044b..714b340db4cde 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -1,28 +1,28 @@
-# 'mesh' Dialect
+# 'shard' Dialect
-The `mesh` dialect contains a set of attributes, operations and interfaces that
-are useful for representing sharding and communication on a device mesh
+The `shard` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device grid
cluster.
[TOC]
## Collective Communication Operations
-There are a number of operations in the Mesh dialect to facilitate
-communication between devices in a mesh.
+There are a number of operations in the Shard dialect to facilitate
+communication between devices in a grid.
It is assumed that the user is familiar with collective operations.
[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
explanation.
-The main addition is that the collectives in this dialect have mesh
+The main addition is that the collectives in this dialect have grid
semantics.
### Device groups
-The operation attributes `mesh` and `mesh_axes` specifies a list of device mesh
+The operation attributes `grid` and `grid_axes` specifies a list of device grid
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
-Devices that have the same coordinates outside of axes `mesh_axes` are in the
+Devices that have the same coordinates outside of axes `grid_axes` are in the
same group.
-A group is described by its multi-index along the axes outside of `mesh_axes`.
-For example if we have a device mesh of size `2x3x4x5` and the partition mesh
+A group is described by its multi-index along the axes outside of `grid_axes`.
+For example if we have a device grid of size `2x3x4x5` and the partition grid
axes list is `[0, 1]` then devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
The device groups would be `{ (k, m) | 0<=k<4, 0<=m<5 }`.
@@ -31,7 +31,7 @@ Device (1, 0, 2, 4) will be in another group.
Some collective operations like all-to-all and all-gather care about the
order of devices.
The order of device in a device group is induced by the order of axes in
-`mesh_axes`.
+`grid_axes`.
The axes are ordered from outer to inner.
If we have an axis list `[3, 1]` then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
@@ -39,11 +39,11 @@ both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
### In-group Device
Some operations like `broadcast`, `scatter` and `send` specify devices in each
device-group.
-These devices are represented with their multi-index over the mesh axes that
+These devices are represented with their multi-index over the grid axes that
are not constant within a device group.
-These are the axes specified by `mesh_axes` attribute.
+These are the axes specified by `grid_axes` attribute.
-For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
+For Example on a 3D grid an operation with `grid_axes = [0, 2]` would specify
an in-group device with `(i, j)`. Then for each group with index `g` on the
second axis, the in-group device would be `(i, g, j)`.
### Purity
@@ -60,15 +60,15 @@ For example if a collective operation is optimized out, than it must also
not appear in any path of execution on any process.
Having the operations as `Pure` implies that if an interpreter is to execute
-the IR containing the `mesh` collectives, all processes would execute the same
+the IR containing the `grid` collectives, all processes would execute the same
line when they reach a pure collective operation.
This requirement stems from the need to be compatible with general optimization
passes like dead code and common sub-expression elimination.
## Operations
-[include "Dialects/MeshOps.md"]
+[include "Dialects/ShardOps.md"]
## Attributes
-[include "Dialects/MeshAttrs.md"]
+[include "Dialects/ShardAttrs.md"]
diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md
index e9d22d1e3dfac..9df32666415bb 100644
--- a/mlir/docs/Passes.md
+++ b/mlir/docs/Passes.md
@@ -72,9 +72,9 @@ This document describes the available MLIR passes and their contracts.
[include "MemRefPasses.md"]
-## 'mesh' Dialect Passes
+## 'shard' Dialect Passes
-[include "MeshPasses.md"]
+[include "ShardPasses.md"]
## 'ml\_program' Dialect Passes
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index d93fbefab74aa..3dc48b2201cf2 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -52,7 +52,6 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
@@ -66,6 +65,7 @@
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 8183f355795a9..eb18160ea2eeb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -903,13 +903,13 @@ def ConvertMemRefToSPIRVPass : Pass<"convert-memref-to-spirv"> {
}
//===----------------------------------------------------------------------===//
-// MeshToMPI
+// ShardToMPI
//===----------------------------------------------------------------------===//
-def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
- let summary = "Convert Mesh dialect to MPI dialect.";
+def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
+ let summary = "Convert Shard dialect to MPI dialect.";
let description = [{
- This pass converts communication operations from the Mesh dialect to the
+ This pass converts communication operations from the Shard dialect to the
MPI dialect.
If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
use that integer value instead of calling MPI_Comm_rank. This allows
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h
similarity index 64%
rename from mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
rename to mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h
index bc64e7a3c1c8c..b1aa08c432249 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/ShardToMPI/ShardToMPI.h
@@ -1,4 +1,4 @@
-//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
+//===- ShardToMPI.h - Convert Shard to MPI dialect --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
-#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#ifndef MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
+#define MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -15,9 +15,9 @@
namespace mlir {
class Pass;
-#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#define GEN_PASS_DECL_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
-#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#endif // MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 56dc97282fa4a..e27b1679c2a52 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
add_subdirectory(MLProgram)
add_subdirectory(MPI)
add_subdirectory(NVGPU)
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h
similarity index 88%
rename from mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
rename to mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h
index 30d3033209d21..e22b24b3446bb 100644
--- a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Func/Extensions/ShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.h - -----------------------------------------===//
+//===- ShardingExtensions.h - -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h
similarity index 54%
rename from mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h
index c57501ea86b7e..dc21bc05a2dc1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h
@@ -1,4 +1,4 @@
-//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//===- ShardingInterfaceImpl.h ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
-#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#ifndef MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace linalg {
-void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir
-#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#endif // MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
deleted file mode 100644
index f26c6285efd89..0000000000000
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ /dev/null
@@ -1,25 +0,0 @@
-add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh)
-add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh)
-
-set(LLVM_TARGET_DEFINITIONS MeshOps.td)
-mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh)
-mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh)
-
-set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
-
-set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
-
-set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshTypes.h.inc -gen-typedef-decls)
-mlir_tablegen(MeshTypes.cpp.inc -gen-typedef-defs)
-
-set(LLVM_TARGET_DEFINITIONS MeshOps.td)
-mlir_tablegen(MeshOps.h.inc -gen-op-decls)
-mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
-
-add_public_tablegen_target(MLIRMeshIncGen)
-add_dependencies(mlir-headers MLIRMeshIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
deleted file mode 100644
index 8d768485103b6..0000000000000
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
-add_public_tablegen_target(MLIRMeshPassIncGen)
-add_dependencies(mlir-headers MLIRMeshPassIncGen)
-
-add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/CMakeLists.txt
similarity index 100%
rename from mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
rename to mlir/include/mlir/Dialect/Shard/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..a2495af135899
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_doc(ShardOps ShardOps Dialects/ -gen-op-doc -dialect=shard)
+add_mlir_doc(ShardOps ShardAttrs Dialects/ -gen-attrdef-doc -dialect=shard)
+
+set(LLVM_TARGET_DEFINITIONS ShardOps.td)
+mlir_tablegen(ShardDialect.cpp.inc -gen-dialect-defs -dialect=shard)
+mlir_tablegen(ShardDialect.h.inc -gen-dialect-decls -dialect=shard)
+
+set(LLVM_TARGET_DEFINITIONS ShardBase.td)
+mlir_tablegen(ShardAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(ShardAttributes.cpp.inc -gen-attrdef-defs)
+
+set(LLVM_TARGET_DEFINITIONS ShardBase.td)
+mlir_tablegen(ShardEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ShardEnums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS ShardBase.td)
+mlir_tablegen(ShardTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(ShardTypes.cpp.inc -gen-typedef-defs)
+
+set(LLVM_TARGET_DEFINITIONS ShardOps.td)
+mlir_tablegen(ShardOps.h.inc -gen-op-decls)
+mlir_tablegen(ShardOps.cpp.inc -gen-op-defs)
+
+add_public_tablegen_target(MLIRShardIncGen)
+add_dependencies(mlir-headers MLIRShardIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Shard/IR/ShardBase.td
similarity index 64%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
rename to mlir/include/mlir/Dialect/Shard/IR/ShardBase.td
index 61403ac178980..41ae31807c825 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardBase.td
@@ -1,4 +1,4 @@
-//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
+//===- ShardBase.td - Shard Dialect ------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
-#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDBASE_TD
+#define MLIR_DIALECT_SHARD_IR_SHARDBASE_TD
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
@@ -16,15 +16,15 @@ include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/EnumAttr.td"
//===----------------------------------------------------------------------===//
-// Mesh Dialect
+// Shard Dialect
//===----------------------------------------------------------------------===//
-def Mesh_Dialect : Dialect {
- let name = "mesh";
- let cppNamespace = "::mlir::mesh";
+def Shard_Dialect : Dialect {
+ let name = "shard";
+ let cppNamespace = "::mlir::shard";
let description = [{
- See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
+ See [Shard dialect documentation](mlir/docs/Dialects/Shard.md).
}];
let dependentDialects = [
@@ -36,16 +36,16 @@ def Mesh_Dialect : Dialect {
let hasConstantMaterializer = 1;
}
-def Mesh_MeshAxis : I<16>;
-def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
-def Mesh_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
+def Shard_GridAxis : I<16>;
+def Shard_GridAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def Shard_ShardShapeAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
//===----------------------------------------------------------------------===//
-// Mesh Enums.
+// Shard Enums.
//===----------------------------------------------------------------------===//
-def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
- "Reduction of an iterator/mesh dimension.", [
+def Shard_ReductionKind : I32EnumAttr<"ReductionKind",
+ "Reduction of an iterator/grid dimension.", [
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
@@ -58,31 +58,31 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
I32EnumAttrCase<"Generic", 100, "generic">
]> {
let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::mesh";
+ let cppNamespace = "::mlir::shard";
}
-def Mesh_ReductionKindAttr : EnumAttr<Mesh_Dialect, Mesh_ReductionKind, "partial"> {
+def Shard_ReductionKindAttr : EnumAttr<Shard_Dialect, Shard_ReductionKind, "partial"> {
let assemblyFormat = "$value";
}
-class Mesh_Type<string name, string typeMnemonic, list<Trait> traits = [],
+class Shard_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
- : TypeDef<Mesh_Dialect, name, traits, baseCppClass> {
+ : TypeDef<Shard_Dialect, name, traits, baseCppClass> {
let mnemonic = typeMnemonic;
}
-def Mesh_Sharding : Mesh_Type<"Sharding", "sharding"> {
+def Shard_Sharding : Shard_Type<"Sharding", "sharding"> {
let summary = "sharding definition";
let assemblyFormat = "";
}
//===----------------------------------------------------------------------===//
-// Mesh Attribute
+// Shard Attribute
//===----------------------------------------------------------------------===//
-def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
+def Shard_GridAxesArrayAttr : AttrDef<Shard_Dialect, "GridAxesArray"> {
let mnemonic = "axisarray";
- let parameters = (ins ArrayRefParameter<"MeshAxesAttr">:$axes);
+ let parameters = (ins ArrayRefParameter<"GridAxesAttr">:$axes);
let assemblyFormat = "`[` $axes `]`";
let extraClassDeclaration = [{
size_t size() const { return getAxes().size(); }
@@ -91,4 +91,4 @@ def Mesh_MeshAxesArrayAttr : AttrDef<Mesh_Dialect, "MeshAxesArray"> {
}];
}
-#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#endif // MLIR_DIALECT_SHARD_IR_SHARDBASE_TD
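As a quick illustration of the renamed definitions above, the sketch below (not part of the patch; grid and function names are illustrative) shows a sharding as an SSA value of the new `!shard.sharding` type, using the `shard.sharding`, `shard.shard` and `shard.get_sharding` syntax defined further down in this patch:
```mlir
shard.grid @grid1d(shape = 4)

func.func @roundtrip(%arg0 : tensor<4x8xf32>) -> !shard.sharding {
  // Split the first tensor dimension along grid axis 0.
  %s = shard.sharding @grid1d split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
  // Query the sharding attached to the annotated tensor value.
  %q = shard.get_sharding %0 : tensor<4x8xf32> -> !shard.sharding
  return %q : !shard.sharding
}
```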
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h
similarity index 57%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
rename to mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h
index a30cf91e851fe..4113a668d4b76 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardDialect.h
@@ -1,4 +1,4 @@
-//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
+//===- ShardOps.h - Shard Dialect -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
-#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H
+#define MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H
#include "mlir/IR/Dialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h.inc"
-#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#endif // MLIR_DIALECT_SHARD_IR_SHARDDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.h
similarity index 52%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
rename to mlir/include/mlir/Dialect/Shard/IR/ShardOps.h
index 7cfe59dd957ca..457fe6f6b8d0a 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.h
@@ -1,4 +1,4 @@
-//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
+//===- ShardOps.h - Shard Dialect Operations --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
-#define MLIR_DIALECT_MESH_IR_MESHOPS_H
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_H
+#define MLIR_DIALECT_SHARD_IR_SHARDOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
@@ -21,45 +21,45 @@
#include "llvm/Support/MathExtras.h"
namespace mlir {
-namespace mesh {
+namespace shard {
-using MeshAxis = int16_t;
-using MeshAxesAttr = DenseI16ArrayAttr;
+using GridAxis = int16_t;
+using GridAxesAttr = DenseI16ArrayAttr;
using ShardShapeAttr = DenseI64ArrayAttr;
using HaloSizePairAttr = DenseI64ArrayAttr;
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardEnums.h.inc"
#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.h.inc"
namespace mlir {
-namespace mesh {
+namespace shard {
-class MeshSharding {
+class Sharding {
private:
- ::mlir::FlatSymbolRefAttr mesh;
- SmallVector<MeshAxesAttr> split_axes;
+ ::mlir::FlatSymbolRefAttr grid;
+ SmallVector<GridAxesAttr> split_axes;
SmallVector<int64_t> static_halo_sizes;
SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
SmallVector<Value> dynamic_sharded_dims_offsets;
public:
- MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
- MeshSharding(Value rhs);
- static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
- ArrayRef<MeshAxesAttr> split_axes_,
- ArrayRef<int64_t> static_halo_sizes_ = {},
- ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
- ArrayRef<Value> dynamic_halo_sizes_ = {},
- ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
- ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
- ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
- ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+ Sharding(::mlir::FlatSymbolRefAttr grid_ = nullptr);
+ Sharding(Value rhs);
+ static Sharding get(::mlir::FlatSymbolRefAttr grid_,
+ ArrayRef<GridAxesAttr> split_axes_,
+ ArrayRef<int64_t> static_halo_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
+ ArrayRef<Value> dynamic_halo_sizes_ = {},
+ ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
+ ::mlir::FlatSymbolRefAttr getGridAttr() const { return grid; }
+ ::llvm::StringRef getGrid() const { return grid ? grid.getValue() : ""; }
+ ArrayRef<GridAxesAttr> getSplitAxes() const { return split_axes; }
ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
ArrayRef<int64_t> getStaticShardedDimsOffsets() const {
return static_sharded_dims_offsets;
@@ -68,28 +68,28 @@ class MeshSharding {
ArrayRef<Value> getDynamicShardedDimsOffsets() const {
return dynamic_sharded_dims_offsets;
}
- operator bool() const { return (!mesh) == false; }
+ operator bool() const { return (!grid) == false; }
bool operator==(Value rhs) const;
bool operator!=(Value rhs) const;
- bool operator==(const MeshSharding &rhs) const;
- bool operator!=(const MeshSharding &rhs) const;
- bool equalSplitAxes(const MeshSharding &rhs) const;
- bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
- bool equalHaloSizes(const MeshSharding &rhs) const;
- bool equalShardSizes(const MeshSharding &rhs) const;
+ bool operator==(const Sharding &rhs) const;
+ bool operator!=(const Sharding &rhs) const;
+ bool equalSplitAxes(const Sharding &rhs) const;
+ bool equalHaloAndShardSizes(const Sharding &rhs) const;
+ bool equalHaloSizes(const Sharding &rhs) const;
+ bool equalShardSizes(const Sharding &rhs) const;
};
-} // namespace mesh
+} // namespace shard
} // namespace mlir
#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.h.inc"
#define GET_OP_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.h.inc"
namespace mlir {
-namespace mesh {
+namespace shard {
inline bool isReductionLoop(utils::IteratorType iType) {
return iType == utils::IteratorType::reduction;
@@ -103,52 +103,52 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
}
// Is the same tensor replicated on all processes.
-inline bool isFullReplication(MeshSharding sharding) {
- return llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
+inline bool isFullReplication(Sharding sharding) {
+ return llvm::all_of(sharding.getSplitAxes(), [](GridAxesAttr axes) {
return axes.asArrayRef().empty();
});
}
-inline mesh::MeshOp
-getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+inline shard::GridOp
+getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol,
SymbolTableCollection &symbolTableCollection) {
- if (!meshSymbol)
+ if (!gridSymbol)
return nullptr;
- return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
- op, meshSymbol);
+ return symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
+ op, gridSymbol);
}
-inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTableCollection) {
- mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
- assert(meshOp);
- return meshOp;
+inline shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol,
+ SymbolTableCollection &symbolTableCollection) {
+ shard::GridOp gridOp = getGridOrNull(op, gridSymbol, symbolTableCollection);
+ assert(gridOp);
+ return gridOp;
}
-// Get the corresponding mesh op using the standard attribute nomenclature.
+// Get the corresponding grid op using the standard attribute nomenclature.
template <typename Op>
-mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
- return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
+shard::GridOp getGrid(Op op, SymbolTableCollection &symbolTableCollection) {
+ return getGrid(op.getOperation(), op.getGridAttr(), symbolTableCollection);
}
template <>
-inline mesh::MeshOp
-getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
- return getMesh(
+inline shard::GridOp
+getGrid<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
+ return getGrid(
op.getOperation(),
- cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
+ cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr(),
symbolTableCollection);
}
// Get the number of processes that participate in each group
-// induced by `meshAxes`.
-template <typename MeshAxesRange, typename MeshShapeRange>
-int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
- MeshShapeRange &&meshShape) {
+// induced by `gridAxes`.
+template <typename GridAxesRange, typename GridShapeRange>
+int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes,
+ GridShapeRange &&gridShape) {
int64_t res = 1;
- for (MeshAxis axis : meshAxes) {
- auto axisSize = *(std::begin(meshShape) + axis);
+ for (GridAxis axis : gridAxes) {
+ auto axisSize = *(std::begin(gridShape) + axis);
if (ShapedType::isDynamic(axisSize)) {
return ShapedType::kDynamic;
}
@@ -158,10 +158,10 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
-template <typename MeshAxesRange>
-int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
- return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
- mesh.getShape());
+template <typename GridAxesRange>
+int64_t collectiveProcessGroupSize(GridAxesRange &&gridAxes, GridOp grid) {
+ return collectiveProcessGroupSize(std::forward<GridAxesRange>(gridAxes),
+ grid.getShape());
}
// Get the size of a sharded dimension.
@@ -182,27 +182,25 @@ inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
}
// Return the sharded shape `shape` according ot sharding `sharding`.
-// The shape for the tensor on each device in the mesh.
+// The shape for the tensor on each device in the grid.
// Example:
-// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
+// On a 2x4x? grid with split axes = [[0], [1], [2]] the shape ?x5x1 would
// result in a shape for each shard of ?x2x?.
-ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
- MeshSharding sharding);
+ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding);
// If ranked tensor type return its sharded counterpart.
//
// If not ranked tensor type return `type`.
// `sharding` in that case must be null.
-Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
+Type shardType(Type type, GridOp grid, Sharding sharding);
// Insert shard op if there is not one that already has the same sharding.
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
+void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result,
OpBuilder &builder);
-void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
+void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand,
OpBuilder &builder);
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
@@ -210,7 +208,7 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
llvm::ArrayRef<int64_t> statics,
ValueRange dynamics, Type type = Type());
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
+#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_H
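The `shardShapedType` comment above includes a worked example; the sketch below (not part of the patch; names are illustrative) restates it in IR form using the `shard.sharding`/`shard.shard` syntax defined later in this patch:
```mlir
shard.grid @grid3d(shape = 2x4x?)

func.func @annotate(%t : tensor<?x5x1xf32>) -> tensor<?x5x1xf32> {
  // Tensor dimension i is split along grid axis i.
  %s = shard.sharding @grid3d split_axes = [[0], [1], [2]] : !shard.sharding
  // Per the header comment, each shard then has local shape ?x2x?.
  %0 = shard.shard %t to %s : tensor<?x5x1xf32>
  return %0 : tensor<?x5x1xf32>
}
```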
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
similarity index 70%
rename from mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
rename to mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 1662885c161e6..29b384f401876 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -1,4 +1,4 @@
-//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
+//===-- ShardOps.td - Shard dialect operation definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
-#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
+#ifndef MLIR_DIALECT_SHARD_IR_SHARDOPS_TD
+#define MLIR_DIALECT_SHARD_IR_SHARDOPS_TD
-include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shard/IR/ShardBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -21,24 +21,24 @@ include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
-// Mesh operations.
+// Shard operations.
//===----------------------------------------------------------------------===//
-class Mesh_Op<string mnemonic, list<Trait> traits = []> :
- Op<Mesh_Dialect, mnemonic, traits> {
+class Shard_Op<string mnemonic, list<Trait> traits = []> :
+ Op<Shard_Dialect, mnemonic, traits> {
}
-def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
- let summary = "Description of a device/process mesh.";
+def Shard_GridOp : Shard_Op<"grid", [Symbol, Pure]> {
+ let summary = "Description of a device/process grid.";
let description = [{
- The mesh.mesh operation is a symbol operation that identifies a specific
- mesh. The operation has three attributes:
+ The shard.grid operation is a symbol operation that identifies a specific
+ grid. The operation has three attributes:
- 1. `sym_name`: This attribute uniquely identifies the name of the mesh.
- This name serves as a symbolic reference to the mesh throughout
+ 1. `sym_name`: This attribute uniquely identifies the name of the grid.
+ This name serves as a symbolic reference to the grid throughout
the MLIR module, allowing for consistent referencing and easier debugging.
- 2. `shape`: This attribute represents the shape of the device mesh.
+ 2. `shape`: This attribute represents the shape of the device grid.
It uses the same notation as a tensor shape. Also allowing for dynamic
dimensions.
This flexibility allows for dynamic device assignment or configurations
@@ -48,21 +48,21 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
Example:
```
- // A device mesh with 3 axes, the total device number is 4 * 8 * 12
+ // A device grid with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
- mesh.mesh @mesh0(shape = 4x8x12)
+ shard.grid @grid0(shape = 4x8x12)
- // A device mesh with 2 axes, the total device number is unknown
+ // A device grid with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
- mesh.mesh @mesh1(shape = 4x?)
+ shard.grid @grid1(shape = 4x?)
- // A device mesh with 2 axes, the total device number is unknown
+ // A device grid with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
- mesh.mesh @mesh2(shape = ?x4)
+ shard.grid @grid2(shape = ?x4)
- // A device mesh with 2 axes, the number of devices along both axes
+ // A device grid with 2 axes, the number of devices along both axes
// is unknown
- mesh.mesh @mesh3(shape = ?x?)
+ shard.grid @grid3(shape = ?x?)
```
}];
let arguments = (ins
@@ -79,15 +79,15 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol, Pure]> {
let hasVerifier = 1;
}
-def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
+def Shard_GridShapeOp : Shard_Op<"grid_shape", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
- let summary = "Get the shape of the mesh.";
+ let summary = "Get the shape of the grid.";
let arguments = (ins
- FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ FlatSymbolRefAttr:$grid,
+ DefaultValuedAttr<Shard_GridAxesAttr, "{}">:$axes
);
let results = (outs
@@ -95,46 +95,46 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
);
let assemblyFormat = [{
- $mesh (`axes` `=` $axes^)?
+ $grid (`axes` `=` $axes^)?
attr-dict `:` type($result)
}];
let builders = [
- OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
- OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+ OpBuilder<(ins "::mlir::shard::GridOp":$grid)>,
+ OpBuilder<(ins "::mlir::shard::GridOp":$grid, "ArrayRef<GridAxis>":$axes)>,
+ OpBuilder<(ins "StringRef":$grid, "ArrayRef<GridAxis>":$axes)>
];
}
-def Mesh_ProcessMultiIndexOp : Mesh_Op<"process_multi_index", [
+def Shard_ProcessMultiIndexOp : Shard_Op<"process_multi_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
- let summary = "Get the multi index of current device along specified mesh axes.";
+ let summary = "Get the multi index of current device along specified grid axes.";
let description = [{
It is used in the SPMD format of IR.
- The `axes` mush be non-negative and less than the total number of mesh axes.
+ The `axes` mush be non-negative and less than the total number of grid axes.
If the axes are empty then get the index along all axes.
}];
let arguments = (ins
- FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ FlatSymbolRefAttr:$grid,
+ DefaultValuedAttr<Shard_GridAxesAttr, "{}">:$axes
);
let results = (outs
Variadic<Index>:$result
);
let assemblyFormat = [{
- `on` $mesh (`axes` `=` $axes^)?
+ `on` $grid (`axes` `=` $axes^)?
attr-dict `:` type($result)
}];
let builders = [
- OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
+ OpBuilder<(ins "::mlir::shard::GridOp":$grid)>,
+ OpBuilder<(ins "StringRef":$grid, "ArrayRef<GridAxis>":$axes)>
];
}
-def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
+def Shard_ProcessLinearIndexOp : Shard_Op<"process_linear_index", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
@@ -143,34 +143,34 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
let description = [{
Example:
```
- %idx = mesh.process_linear_index on @mesh : index
+ %idx = shard.process_linear_index on @grid : index
```
- if `@mesh` has shape `(10, 20, 30)`, a device with multi
+ if `@grid` has shape `(10, 20, 30)`, a device with multi
index `(1, 2, 3)` will have linear index `3 + 30*2 + 20*30*1`.
}];
- let arguments = (ins FlatSymbolRefAttr:$mesh);
+ let arguments = (ins FlatSymbolRefAttr:$grid);
let results = (outs Index:$result);
- let assemblyFormat = "`on` $mesh attr-dict `:` type($result)";
+ let assemblyFormat = "`on` $grid attr-dict `:` type($result)";
let builders = [
- OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>
+ OpBuilder<(ins "::mlir::shard::GridOp":$grid)>
];
}
-def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+def Shard_NeighborsLinearIndicesOp : Shard_Op<"neighbors_linear_indices", [
Pure,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary =
- "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
+ "For given grid index get the linear indices of the direct neighbor processes along the given split.";
let description = [{
Example:
```
- mesh.mesh @mesh0(shape = 10x20x30)
+ shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
- %idx = mesh.neighbors_linear_indices on @mesh[%c1, %c2, %c3] split_axes = [1] : index
+ %idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
```
The above returns two indices, `633` and `693`, which correspond to the
index of the previous process `(1, 1, 3)`, and the next process
@@ -179,12 +179,12 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
A negative value is returned if there is no neighbor in the respective
direction along the given `split_axes`.
}];
- let arguments = (ins FlatSymbolRefAttr:$mesh,
+ let arguments = (ins FlatSymbolRefAttr:$grid,
Variadic<Index>:$device,
- Mesh_MeshAxesAttr:$split_axes);
+ Shard_GridAxesAttr:$split_axes);
let results = (outs Index:$neighbor_down, Index:$neighbor_up);
let assemblyFormat = [{
- `on` $mesh `[` $device `]`
+ `on` $grid `[` $device `]`
`split_axes` `=` $split_axes
attr-dict `:` type(results)
}];
@@ -194,7 +194,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
// Sharding operations.
//===----------------------------------------------------------------------===//
-def Mesh_ShardingOp : Mesh_Op<"sharding", [
+def Shard_ShardingOp : Shard_Op<"sharding", [
Pure,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -202,18 +202,18 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
]> {
let summary = "Define a sharding of a tensor.";
let description = [{
- The MeshSharding specifies how a tensor is sharded and distributed across the
- process mesh. It is typically used in a `mesh.shard` operation.
+ The Sharding specifies how a tensor is sharded and distributed across the
+ process shard. It is typically used in a `shard.shard` operation.
The operation has the following attributes and operands:
- 1. `mesh`: this attribute is a FlatSymbolRefAttr that refers to the device
- mesh where the distributed tensor is placed. The symbol must resolve to a
- `mesh.mesh` operation.
+ 1. `grid`: this attribute is a FlatSymbolRefAttr that refers to the device
+ grid where the distributed tensor is placed. The symbol must resolve to a
+ `shard.grid` operation.
2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
maximum size is the `rank` of the related tensor. For the i-th sub-array, if
its value is [x, y], it indicates that the tensor's i-th dimension is splitted
- along the x and y axes of the device mesh.
+ along the x and y axes of the device grid.
3. [Optional] Sizes of halos to be added for each sharded tensor dimension.
`halo_sizes` is provided as a flattened 1d array of i64s, 2 values for each
@@ -233,7 +233,7 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
`sharded_dims_offsets` = [0, 24, 32, 0, 20, 32] means that the first device of
- the device-mesh will get a shard of shape 24x20x32 and the second device will get
+ the device-grid will get a shard of shape 24x20x32 and the second device will get
a shard of shape 8x12x32. `?` indicates dynamic shard dimensions.
`halo_sizes` and `sharded_dims_offsets` are mutually exclusive.
@@ -241,101 +241,101 @@ def Mesh_ShardingOp : Mesh_Op<"sharding", [
Examples:
```
- mesh.mesh @mesh0(shape = 2x2x4)
- mesh.mesh @mesh1d_4(shape = 4)
+ shard.grid @grid0(shape = 2x2x4)
+ shard.grid @grid1d_4(shape = 4)
- // The tensor is fully replicated on @mesh0.
+ // The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
- %sharding0 = mesh.sharding @mesh0 split_axes = [[]]
+ %sharding0 = shard.sharding @grid0 split_axes = [[]]
- // The tensor is sharded on the first dimension along axis 0 of @mesh0
- %sharding1 = mesh.sharding @mesh0 split_axes = [[0]]
+ // The tensor is sharded on the first dimension along axis 0 of @grid0
+ %sharding1 = shard.sharding @grid0 split_axes = [[0]]
- // Could be used for a mesh.shard op
- %sharded0 = mesh.shard %arg0 to %sharding3 : tensor<4x8xf32>
+ // Could be used for a shard.shard op
+ %sharded0 = shard.shard %arg0 to %sharding3 : tensor<4x8xf32>
- // The tensor is sharded on its first dimension along axis 0 of @mesh0 and
+ // The tensor is sharded on its first dimension along axis 0 of @grid0 and
// and it has halo-sizes of 1 and 2 on the sharded dim.
- %halo_sharding = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2]
- %sharded1 = mesh.shard %arg0 to %halo_sharding : tensor<4x8xf32>
+ %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
+ %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>
- // The tensor is sharded on its second dimension along axis 0 of @mesh1d_4
+ // The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
- %sharding4 = mesh.sharding @mesh1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
- %sharded2 = mesh.shard %arg0 to %sharding4 : tensor<4x14xf32>
+ %sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
+ %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
```
}];
let arguments = (ins
- FlatSymbolRefAttr:$mesh,
- Mesh_MeshAxesArrayAttr:$split_axes,
+ FlatSymbolRefAttr:$grid,
+ Shard_GridAxesArrayAttr:$split_axes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_sharded_dims_offsets,
Variadic<I64>:$dynamic_sharded_dims_offsets,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes,
Variadic<I64>:$dynamic_halo_sizes
);
let results = (outs
- Mesh_Sharding:$result
+ Shard_Sharding:$result
);
let assemblyFormat = [{
- $mesh
+ $grid
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($dynamic_halo_sizes, $static_halo_sizes)^)?
(`sharded_dims_offsets` `=` custom<DynamicIndexList>($dynamic_sharded_dims_offsets, $static_sharded_dims_offsets)^)?
attr-dict `:` type($result)
}];
let builders = [
- OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<MeshAxesAttr>":$split_axes,
+ OpBuilder<(ins "FlatSymbolRefAttr":$grid,
+ "ArrayRef<GridAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets)>,
- OpBuilder<(ins "FlatSymbolRefAttr":$mesh,
- "ArrayRef<MeshAxesAttr>":$split_axes,
+ OpBuilder<(ins "FlatSymbolRefAttr":$grid,
+ "ArrayRef<GridAxesAttr>":$split_axes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$halo_sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sharded_dims_offsets)>,
- OpBuilder<(ins "llvm::StringRef":$mesh,
- "ArrayRef<MeshAxesAttr>":$split_axes,
+ OpBuilder<(ins "llvm::StringRef":$grid,
+ "ArrayRef<GridAxesAttr>":$split_axes,
CArg<"ArrayRef<int64_t>", "{}">:$static_halo_sizes,
CArg<"ArrayRef<int64_t>", "{}">:$static_sharded_dims_offsets
)>,
- OpBuilder<(ins "mlir::mesh::MeshSharding":$from)>
+ OpBuilder<(ins "mlir::shard::Sharding":$from)>
];
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
-def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
+def Shard_GetShardingOp : Shard_Op<"get_sharding", [Pure]> {
let summary = "Get the sharding of the given tensor.";
let description = [{
- This operation returns the sharding of the given tensor as a MeshSharding.
+ This operation returns the sharding of the given tensor as a Sharding.
}];
let arguments = (ins
AnyRankedTensor:$source
);
let results = (outs
- Mesh_Sharding:$result
+ Shard_Sharding:$result
);
let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($result)
}];
}
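For illustration only (not part of the patch): a minimal usage sketch of the renamed `shard.get_sharding` op, following the assembly format above. The value name, tensor shape, and producing op are hypothetical.

```mlir
// Query the sharding attached to a tensor value.
%s = shard.get_sharding %t : tensor<4x8xf32> -> !shard.sharding
```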
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+def Shard_ShardShapeOp : Shard_Op<"shard_shape", [
Pure, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "Get the shard shape for a given process/device.";
let description = [{
- The device/process id is a multi-index of the device/process in the mesh.
- This operation might be used during spmdization when the shard shape depends
- on (non-constant) values used in `mesh.sharding`.
+ The device/process id is a multi-index of the device/process in the grid.
+ This operation might be used during partitioning when the shard shape depends
+ on (non-constant) values used in `shard.sharding`.
}];
let arguments = (ins
DenseI64ArrayAttr:$dims,
Variadic<Index>:$dims_dynamic,
- Mesh_Sharding:$sharding,
+ Shard_Sharding:$sharding,
DenseI64ArrayAttr:$device,
Variadic<Index>:$device_dynamic
);
@@ -351,23 +351,23 @@ def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
];
}
-def Mesh_ShardOp : Mesh_Op<"shard", [
+def Shard_ShardOp : Shard_Op<"shard", [
Pure,
AllTypesMatch<["result", "src"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
- let summary = "Annotate on how a tensor is sharded across a mesh.";
+ let summary = "Annotate how a tensor is sharded across a device grid.";
let description = [{
- The mesh.shard operation is designed to specify and guide the sharding
- behavior of a tensor value across a mesh topology. This operation has two
+ The shard.shard operation is designed to specify and guide the sharding
+ behavior of a tensor value across a grid topology. This operation has two
operands and two optional attributes:
1. `input`: This operand represents the tensor value that needs to be
annotated for sharding.
- 2. `sharding`: This attribute is type of `MeshShardingType`, which is the core data
- structure to represent distribution of a tensor on a mesh. it is typically defiend
- by an `mesh.sharding` operation.
+ 2. `sharding`: This attribute is of type `ShardingType`, which is the core data
+ structure to represent the distribution of a tensor on a device grid. It is
+ typically defined by a `shard.sharding` operation.
3. `annotate_for_users`: A unit attribute addressing the scenario when a
tensor's sharding annotation differs based on its context of use (either as
@@ -378,36 +378,36 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
Example:
```
- func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
- %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
+ func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
+ %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
...
}
func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
- %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+ %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
...
}
func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
- %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
- %1 = mesh.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
+ %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
+ %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
...
}
- // The first mesh.shard op applies to %arg0, the second mesh.shard op
- // applies for the operand of op0, the third mesh.shard op applies for the
+ // The first shard.shard op applies to %arg0, the second shard.shard op
+ // applies for the operand of op0, the third shard.shard op applies for the
// operand of op2
func.func @both_result_and_multi_operands_annotated(
%arg0 : tensor<4x8xf32>) -> () {
- %sharding = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding : tensor<4x8xf32>
- %sharding1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %1 = mesh.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
- %sharding2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
- %2 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
+ %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
+ %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+ %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
+ %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
"op0"(%1) : ...
"op1"(%2) : ...
...
@@ -418,44 +418,44 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
```
func.func @annotate_on_same_result_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
- %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to $sharding1 : tensor<4x8xf32>
- %1 = mesh.shard %0 to sharding2 : tensor<4x8xf32>
+ %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
+ %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_result_same_value_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
- %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding1 : tensor<4x8xf32>
- %1 = mesh.shard %arg0 to %sharding2 : tensor<4x8xf32>
+ %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
+ %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
...
}
func.func @annotate_on_same_operand_with_different_sharding(
%arg0 : tensor<4x8xf32>) -> () {
- %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
- %1 = mesh.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
+ %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+ %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
...
}
func.func @result_annotated_after_operand(
%arg0 : tensor<4x8xf32>) -> () {
- %sharding1 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
- %1 = mesh.shard %0 to %sharding2 : tensor<4x8xf32>
+ %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
+ %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
...
}
```
}];
let arguments = (ins
AnyRankedTensor:$src,
- Mesh_Sharding:$sharding,
+ Shard_Sharding:$sharding,
UnitAttr:$annotate_for_users
);
let results = (outs
@@ -473,34 +473,34 @@ def Mesh_ShardOp : Mesh_Op<"shard", [
// collective communication ops
//===----------------------------------------------------------------------===//
-class Mesh_CollectiveCommunicationOpBase<
+class Shard_CollectiveCommunicationOpBase<
string mnemonic, list<Trait> traits = []> :
- Mesh_Op<mnemonic,
+ Shard_Op<mnemonic,
!listconcat(traits,
[
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
])> {
dag commonArgs = (ins
- FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
+ FlatSymbolRefAttr:$grid,
+ DefaultValuedAttr<Shard_GridAxesAttr, "{}">:$grid_axes
);
}
-def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank,
]> {
- let summary = "All-gather over a device mesh.";
+ let summary = "All-gather over a device grid.";
let description = [{
Gathers along the `gather_axis` tensor axis.
Example:
```mlir
- mesh.mesh @mesh0(shape = 2x2)
+ shard.grid @grid0(shape = 2x2)
...
- %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
+ %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
: tensor<2x2xi8> -> tensor<2x4xi8>
```
Input:
@@ -535,16 +535,16 @@ def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
}
-def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+def Shard_AllReduceOp : Shard_CollectiveCommunicationOpBase<"all_reduce", [
Pure,
SameOperandsAndResultShape]> {
- let summary = "All-reduce over a device mesh.";
+ let summary = "All-reduce over a device grid.";
let description = [{
The accumulation element type is specified by the result type and
it does not need to match the input element type.
@@ -556,34 +556,34 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
Example:
```
- %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
+ %1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max>
: tensor<3x4xf32> -> tensor<3x4xf64>
```
}];
let arguments = !con(commonArgs, (ins
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
- DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
+ DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction
));
let results = (outs
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)? (`reduction` `=` $reduction^)?
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
let builders = [
- OpBuilder<(ins "Value":$input, "StringRef":$mesh,
- "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+ OpBuilder<(ins "Value":$input, "StringRef":$grid,
+ "ArrayRef<GridAxis>":$gridAxes, "ReductionKind":$reduction)>
];
}
-def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
+def Shard_AllSliceOp : Shard_CollectiveCommunicationOpBase<"all_slice", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
- let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
+ let summary = "All-slice over a device grid. This is the inverse of all-gather.";
let description = [{
Slice along the `slice_axis` tensor axis.
This operation can be thought of as the inverse of all-gather.
@@ -593,9 +593,9 @@ def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
Example:
```mlir
- mesh.mesh @mesh0(shape = 2x2)
+ shard.grid @grid0(shape = 2x2)
...
- %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
+ %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
```
Input:
@@ -630,30 +630,30 @@ def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
let builders = [
- OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
- OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
+ OpBuilder<(ins "Value":$input, "GridOp":$grid, "ArrayRef<GridAxis>":$gridAxes, "int64_t":$sliceAxis)>,
+ OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$grid, "ArrayRef<GridAxis>":$gridAxes, "int64_t":$sliceAxis)>
];
}
-def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+def Shard_AllToAllOp : Shard_CollectiveCommunicationOpBase<"all_to_all", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank]> {
- let summary = "All-to-all over a device mesh.";
+ let summary = "All-to-all over a device grid.";
let description = [{
Performs an all-to-all on tensor pieces split along `split_axis`.
The resulting pieces are concatenated along `concat_axis` on each device.
Example:
```
- mesh.mesh @mesh0(shape = 3)
+ shard.grid @grid0(shape = 3)
...
- %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
+ %1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 0
: tensor<3x2xi8> -> tensor<3x2xi8>
```
@@ -687,7 +687,7 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`split_axis` `=` $split_axis
`concat_axis` `=` $concat_axis
attr-dict `:` type($input) `->` type($result)
@@ -695,24 +695,24 @@ def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
let hasCanonicalizer = 1;
}
-def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
+def Shard_BroadcastOp : Shard_CollectiveCommunicationOpBase<"broadcast", [
Pure,
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
- let summary = "Broadcast over a device mesh.";
+ let summary = "Broadcast over a device grid.";
let description = [{
Broadcast the tensor on `root` to all devices in each respective group.
- The operation broadcasts along mesh axes `mesh_axes`.
+ The operation broadcasts along grid axes `grid_axes`.
The `root` device specifies the in-group multi-index that is broadcast to
all other devices in the group.
Example:
```
- mesh.mesh @mesh0(shape = 2x2)
+ shard.grid @grid0(shape = 2x2)
- %1 = mesh.broadcast %0 on @mesh0
- mesh_axes = [0]
+ %1 = shard.broadcast %0 on @grid0
+ grid_axes = [0]
root = [0]
: (tensor<2xi8>) -> tensor<2xi8>
```
@@ -744,31 +744,31 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}
-def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
+def Shard_GatherOp : Shard_CollectiveCommunicationOpBase<"gather", [
Pure,
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
- let summary = "Gather over a device mesh.";
+ let summary = "Gather over a device grid.";
let description = [{
Gathers on device `root` along the `gather_axis` tensor axis.
- `root` specifies the coordinates of a device along `mesh_axes`.
+ `root` specifies the coordinates of a device along `grid_axes`.
It uniquely identifies the root device for each device group.
The result tensor on non-root devices is undefined.
Using it will result in undefined behavior.
Example:
```mlir
- mesh.mesh @mesh0(shape = 2x2)
+ shard.grid @grid0(shape = 2x2)
...
- %1 = mesh.gather %0 on @mesh0 mesh_axes = [1]
+ %1 = shard.gather %0 on @grid0 grid_axes = [1]
gather_axis = 1 root = [1]
: (tensor<2x2xi8>) -> tensor<2x4xi8>
```
@@ -807,7 +807,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`gather_axis` `=` $gather_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
@@ -815,11 +815,11 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
let hasCanonicalizer = 1;
}
-def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
+def Shard_RecvOp : Shard_CollectiveCommunicationOpBase<"recv", [
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
- let summary = "Send over a device mesh.";
+ let summary = "Receive over a device grid.";
let description = [{
Receive from a device within a device group.
}];
@@ -832,21 +832,21 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}
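Illustrative sketch only (not part of the patch), assembled from the format above; the grid, axes, source index, and shapes are hypothetical.

```mlir
// Receive a tensor from the device at in-group index [1] along grid axis 0.
%1 = shard.recv %0 on @grid0 grid_axes = [0] source = [1]
    : (tensor<2xi8>) -> tensor<2xi8>
```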
-def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
+def Shard_ReduceOp : Shard_CollectiveCommunicationOpBase<"reduce", [
Pure,
AllShapesMatch<["input", "result"]>
]> {
- let summary = "Reduce over a device mesh.";
+ let summary = "Reduce over a device grid.";
let description = [{
Reduces on device `root` within each device group.
- `root` specifies the coordinates of a device along `mesh_axes`.
+ `root` specifies the coordinates of a device along `grid_axes`.
It uniquely identifies the root device within its device group.
The accumulation element type is specified by the result type and
it does not need to match the input element type.
@@ -858,14 +858,14 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
Example:
```
- %1 = mesh.reduce %0 on @mesh0 mesh_axes = [1, 0]
+ %1 = shard.reduce %0 on @grid0 grid_axes = [1, 0]
reduction = <max> root = [2, 3]
: (tensor<3x4xf32>) -> tensor<3x4xf64>
```
}];
let arguments = !con(commonArgs, (ins
AnyRankedTensor:$input,
- DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
+ DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
DenseI64ArrayAttr:$root,
Variadic<Index>:$root_dynamic
));
@@ -873,7 +873,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
@@ -881,19 +881,19 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
let hasCanonicalizer = 1;
}
-def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter", [
Pure,
SameOperandsAndResultRank]> {
- let summary = "Reduce-scatter over a device mesh.";
+ let summary = "Reduce-scatter over a device grid.";
let description = [{
After the reduction, the result is scattered within each device group.
The tensor is split along `scatter_axis` and the pieces distributed
across the device group.
Example:
```
- mesh.mesh @mesh0(shape = 2x2)
+ shard.grid @grid0(shape = 2x2)
...
- %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
+ %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
reduction = <max> scatter_axis = 0
: tensor<3x4xf32> -> tensor<1x4xf64>
```
@@ -928,14 +928,14 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
- DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction,
+ DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
IndexAttr:$scatter_axis
));
let results = (outs
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
(`reduction` `=` $reduction^)?
`scatter_axis` `=` $scatter_axis
attr-dict `:` type($input) `->` type($result)
@@ -943,20 +943,20 @@ def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter",
let hasCanonicalizer = 1;
}
-def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
+def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
Pure,
AllRanksMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
- let summary = "Scatter over a device mesh.";
+ let summary = "Scatter over a device grid.";
let description = [{
For each device group split the input tensor on the `root` device along
axis `scatter_axis` and scatter the parts across the group devices.
Example:
```
- mesh.mesh @mesh0(shape = 2x2)
- %1 = mesh.scatter %0 on @mesh0 mesh_axes = [0]
+ shard.grid @grid0(shape = 2x2)
+ %1 = shard.scatter %0 on @grid0 grid_axes = [0]
scatter_axis = 0
root = [1]
: (tensor<2x2xi8>) -> tensor<1x2xi8>
@@ -1004,7 +1004,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`scatter_axis` `=` $scatter_axis
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
@@ -1012,11 +1012,11 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
let hasCanonicalizer = 1;
}
-def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
+def Shard_SendOp : Shard_CollectiveCommunicationOpBase<"send", [
AllShapesMatch<["input", "result"]>,
AllElementTypesMatch<["input", "result"]>
]> {
- let summary = "Send over a device mesh.";
+ let summary = "Send over a device grid.";
let description = [{
Send from one device to another within a device group.
}];
@@ -1029,38 +1029,38 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}
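Illustrative sketch only (not part of the patch), following the assembly format above; grid, axes, destination index, and shapes are hypothetical.

```mlir
// Send a tensor to the device at in-group index [1] along grid axis 0.
%1 = shard.send %0 on @grid0 grid_axes = [0] destination = [1]
    : (tensor<2xi8>) -> tensor<2xi8>
```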
-def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
+def Shard_ShiftOp : Shard_CollectiveCommunicationOpBase<"shift", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultShape
]> {
- let summary = "Shift over a device mesh.";
+ let summary = "Shift over a device grid.";
let description = [{
- Within each device group shift along mesh axis `shift_axis` by an offset
+ Within each device group shift along grid axis `shift_axis` by an offset
`offset`.
The result on devices that do not have a corresponding source is undefined.
- `shift_axis` must be one of `mesh_axes`.
+ `shift_axis` must be one of `grid_axes`.
If the `rotate` attribute is present,
instead of a shift a rotation is done.
Example:
```
- mesh.mesh @mesh0(shape = 2x4)
- %1 = mesh.shift on @mesh0 mesh_axes = [1]
+ shard.grid @grid0(shape = 2x4)
+ %1 = shard.shift on @grid0 grid_axes = [1]
shift_axis = 1 offset = 2 rotate
: tensor<2xi8> -> tensor<2xi8>
```
Input:
```
- mesh axis 1
+ grid axis 1
----------->
+----+----+----+----+
@@ -1089,7 +1089,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
AnyRankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
+ $input `on` $grid (`grid_axes` `=` $grid_axes^)?
`shift_axis` `=` $shift_axis
`offset` `=` $offset
(`rotate` $rotate^)?
@@ -1098,7 +1098,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
let hasCanonicalizer = 1;
}
-def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
+def Shard_UpdateHaloOp : Shard_Op<"update_halo", [
Pure,
DestinationStyleOpInterface,
TypesMatchWith<
@@ -1120,14 +1120,14 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
`destination_halo_sizes/static_destination_halo_sizes` in source shard
and destination/result shard.
- `split_axes` specifies for each tensor axis along which mesh axes its halo
+ `split_axes` specifies for each tensor axis along which grid axes its halo
data is updated.
}];
let arguments = (ins
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
- FlatSymbolRefAttr:$mesh,
- Mesh_MeshAxesArrayAttr:$split_axes,
+ FlatSymbolRefAttr:$grid,
+ Shard_GridAxesArrayAttr:$split_axes,
Variadic<I64>:$halo_sizes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
);
@@ -1136,7 +1136,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
);
let assemblyFormat = [{
$destination
- `on` $mesh
+ `on` $grid
`split_axes` `=` $split_axes
(`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
attr-dict `:` type($result)
@@ -1145,4 +1145,4 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
}];
}
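Illustrative sketch only (not part of the patch): a hypothetical `shard.update_halo` use based on the assembly format above; the destination value, grid, shapes, and halo sizes are assumptions.

```mlir
// Refresh halo regions of size 1 (low) and 2 (high) along the tensor
// dimension that is split over grid axis 0.
%updated = shard.update_halo %dest on @grid0 split_axes = [[0]]
    halo_sizes = [1, 2] : tensor<7x8xf32>
```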
-#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
+#endif // MLIR_DIALECT_SHARD_IR_SHARDOPS_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
similarity index 100%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
rename to mlir/include/mlir/Dialect/Shard/Interfaces/CMakeLists.txt
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h
similarity index 52%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h
index 14aad7f9f6783..55de06b137e8d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
@@ -20,24 +20,24 @@ class Operation;
class IRMapping;
class SymbolTableCollection;
-namespace mesh {
+namespace shard {
-using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
-using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
+using ShardingArray = SmallVector<SmallVector<GridAxis>>;
+using ShardingArrayRef = ArrayRef<SmallVector<GridAxis>>;
struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
- // mesh axes the i-th loop will be sharded on.
+ // grid axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
- FlatSymbolRefAttr mesh = nullptr;
+ FlatSymbolRefAttr grid = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
- ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
- : shardingArray(std::move(shardingArray)), mesh(mesh) {
- assert(this->mesh);
+ ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr grid)
+ : shardingArray(std::move(shardingArray)), grid(grid) {
+ assert(this->grid);
}
static ShardingOption makeEmpty() {
auto res = ShardingOption();
@@ -46,21 +46,21 @@ struct ShardingOption {
}
};
-// This method retrieves the 'MeshSharding' from a given operation
+// This method retrieves the 'Sharding' from a given operation
// result and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpResult result);
+FailureOr<std::pair<bool, Sharding>> getSharding(OpResult result);
-// This method retrieves the 'MeshSharding' from a given operation
+// This method retrieves the 'Sharding' from a given operation
// operand and includes the 'annotate_for_users' information.
-FailureOr<std::pair<bool, MeshSharding>> getMeshSharding(OpOperand &opOperand);
+FailureOr<std::pair<bool, Sharding>> getSharding(OpOperand &opOperand);
namespace detail {
FailureOr<ShardingOption>
-defaultGetShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings);
+defaultGetShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings);
-FailureOr<std::vector<MeshSharding>>
+FailureOr<std::vector<Sharding>>
defaultGetShardingAnnotations(Operation *op,
const ShardingOption &shardingOption);
@@ -71,18 +71,18 @@ defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
} // namespace detail
// Assumes full replication on all ranked tensor arguments and results.
-void spmdizeFullyReplicatedOperation(Operation &op,
- ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder);
-
-} // namespace mesh
+void partitionFullyReplicatedOperation(Operation &op,
+ ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder);
+
+} // namespace shard
} // namespace mlir
/// Include the ODS generated interface header files.
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc"
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td
similarity index 80%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td
index a70d2c3e03851..34b0813938b61 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD
include "mlir/IR/OpBase.td"
@@ -16,7 +16,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
Interface for allowing operations to expose information needed to
shard them.
}];
- let cppNamespace = "::mlir::mesh";
+ let cppNamespace = "::mlir::shard";
let methods = [
InterfaceMethod<
@@ -84,8 +84,8 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
/*retTy=*/"FailureOr<ShardingOption>",
/*methodName=*/"getShardingOption",
/*args=*/(ins
- "ArrayRef<MeshSharding>": $operandShardings,
- "ArrayRef<MeshSharding>": $resultShardings
+ "ArrayRef<Sharding>": $operandShardings,
+ "ArrayRef<Sharding>": $resultShardings
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -100,7 +100,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
This is what shardings the operands and results need to have in order
to shard the op according to shardingOption.
}],
- /*retTy=*/"FailureOr<std::vector<MeshSharding>>",
+ /*retTy=*/"FailureOr<std::vector<Sharding>>",
/*methodName=*/"getShardingAnnotations",
/*args=*/(ins
"const ShardingOption &":$shardingOption
@@ -113,7 +113,7 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
>,
InterfaceMethod<
/*desc=*/[{
- Based on a given ShardingOption, this method adds `mesh.shard`
+ Based on a given ShardingOption, this method adds `shard.shard`
operations for the operands and results that previously lacked
sharding annotations.
}],
@@ -132,21 +132,21 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
InterfaceMethod<
/*desc=*/[{
Convert self to SPMD form.
- This method is used during the spmdization pass of a program fully
+ This method is used during the partition pass of a program fully
annotated with shardings.
- The spmdization algorithm would read the surrounding sharding
+ The partition algorithm would read the surrounding sharding
annotations from the IR for each argument/result and prepare
`operandShardings` and `resultShardings`.
Values that are not ranked tensors do not have sharding annotations.
- In this case their corresponding MeshSharding is null.
+ In this case their corresponding Sharding is null.
- For convenience it will also prepare `spmdizedOperands`, although
- they can be retrieved from the `spmdizationMap`.
+ For convenience it will also prepare `partitiondOperands`, although
+ they can be retrieved from the `partitionMap`.
- The `spmdizationMap` contains a mapping from unsharded to
- sharded/spmdized values that are constructed during the spmdization
- pass. The interface implementation must populate `spmdizationMap`
+ The `partitionMap` contains a mapping from unsharded to
+ sharded/partitioned values that are constructed during the partition
+ pass. The interface implementation must populate `partitionMap`
with the mapping for this op's results.
`builder` is set to insert new operations in the appropriate point.
@@ -158,20 +158,20 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
This assumes that all sharding annotations are for full replication.
}],
/*retTy=*/"LogicalResult",
- /*methodName=*/"spmdize",
+ /*methodName=*/"partition",
/*args=*/(ins
- "ArrayRef<Value>": $spmdizedOperands,
- "ArrayRef<MeshSharding>": $operandShardings,
- "ArrayRef<MeshSharding>": $resultShardings,
- "IRMapping&": $spmdizationMap,
+ "ArrayRef<Value>": $partitiondOperands,
+ "ArrayRef<Sharding>": $operandShardings,
+ "ArrayRef<Sharding>": $resultShardings,
+ "IRMapping&": $partitionMap,
"SymbolTableCollection &": $symbolTableCollection,
"OpBuilder &":$builder
),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- spmdizeFullyReplicatedOperation(
- *$_op.getOperation(), spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap, symbolTableCollection, builder);
+ partitionFullyReplicatedOperation(
+ *$_op.getOperation(), partitiondOperands, operandShardings,
+ resultShardings, partitionMap, symbolTableCollection, builder);
return success();
}]>
];
@@ -184,4 +184,4 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
}
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h
similarity index 58%
rename from mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h
index 2af8b2bd1d906..6f047246aca6d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
-#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Value.h"
@@ -20,35 +20,34 @@ class Operation;
class IRMapping;
class SymbolTableCollection;
-namespace mesh {
+namespace shard {
-// Retrieve the mesh axes corresponding to each operation loop iterator based
+// Retrieve the grid axes corresponding to each operation loop iterator based
// on the provided shardings for the op's operands and results.
// Assumes that the indexingMaps are projected permutations.
-ShardingArray getMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+ShardingArray getGridAxisAssignmentForLoopIterators(
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps);
bool isAtLeastOneReductionIteratorSharded(
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
-// Get the set of mesh axes that correspond to reduction loop iterators.
-SmallVector<MeshAxis> getReductionMeshAxes(
+// Get the set of grid axes that correspond to reduction loop iterators.
+SmallVector<GridAxis> getReductionGridAxes(
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
-void spmdizeTriviallyShardableOperation(Operation &op,
- ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder);
+void partitionTriviallyShardableOperation(Operation &op,
+ ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder);
// All ranked tensor argument and result dimensions have
// independent parallel loop iterators.
@@ -73,15 +72,15 @@ struct IndependentParallelIteratorDomainShardingInterface
return SmallVector<AffineMap>();
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
- spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTable, builder);
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ partitionTriviallyShardableOperation(*op, partitiondOperands,
+ operandShardings, resultShardings,
+ partitionMap, symbolTable, builder);
return success();
}
@@ -129,20 +128,20 @@ struct ElementwiseShardingInterface
return maps;
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
- spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTable, builder);
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ partitionTriviallyShardableOperation(*op, partitiondOperands,
+ operandShardings, resultShardings,
+ partitionMap, symbolTable, builder);
return success();
}
};
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
+#endif // MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
new file mode 100644
index 0000000000000..9e2c8d00b27f5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shard)
+add_public_tablegen_target(MLIRShardPassIncGen)
+add_dependencies(mlir-headers MLIRShardPassIncGen)
+
+add_mlir_doc(Passes ShardPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
similarity index 61%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
index 2f6de3e134319..37903765903db 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//===- Partition.h - Shard Partitioning -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,35 +6,35 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/DialectRegistry.h"
namespace mlir {
-namespace mesh {
+namespace shard {
-// Insert resharding spmdization of the value `sourceShardValue`
+// Insert a resharding of the value `sourceShardValue`
// from sharding `source` to sharding `target`.
// `sourceShardValue` is the already sharded value according to `source`.
//
// Example
//
// ```mlir
-// mesh.mesh @mesh_1d(shape = 2)
+// shard.grid @grid_1d(shape = 2)
// ...
-// %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8>
-// %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+// %1 = shard.shard %0 to <@grid_1d, [[0]]> : tensor<2xi8>
+// %2 = shard.shard %1 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8>
// ```
//
// Will result in
//
// ```mlir
-// %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 :
+// %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 :
// tensor<1xi8> -> tensor<2xi8>
// ```
-TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue);
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@@ -44,7 +44,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
void reshardingRegisterDependentDialects(DialectRegistry &registry);
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PARTITION_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.h
similarity index 75%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Passes.h
index a2424d43a8ba9..88bb460255728 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.h
@@ -1,4 +1,4 @@
-//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//===- Passes.h - Shard Passes ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
@@ -17,7 +17,7 @@ namespace func {
class FuncOp;
}
-namespace mesh {
+namespace shard {
/// This enum controls the traversal order for the sharding propagation.
enum class TraversalOrder {
@@ -36,16 +36,16 @@ enum class TraversalOrder {
//===----------------------------------------------------------------------===//
#define GEN_PASS_DECL
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
similarity index 65%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
rename to mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
index 11ec7e78cd5e6..bbc6a1977b13e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
@@ -1,4 +1,4 @@
-//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//===-- Passes.td - Shard transformation definition file ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
-#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD
include "mlir/Pass/PassBase.td"
@@ -20,31 +20,31 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
let summary = "sharding propagation";
let description = [{
Propagates sharding information throughout the graph. After this pass, each
- of the operations' operands and results is annotated with a `mesh.shard`
+ of the operations' operands and results is annotated with a `shard.shard`
operation, and the operations themselves carry sharding option
attributes.
}];
let options = [
Option<"traversal", "traversal",
- "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+ "mlir::shard::TraversalOrder", /*default=*/"mlir::shard::TraversalOrder::BackwardForward",
"Traversal order to use for sharding propagation:",
[{::llvm::cl::values(
- clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+ clEnumValN(mlir::shard::TraversalOrder::Forward, "forward",
"Forward only traversal."),
- clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+ clEnumValN(mlir::shard::TraversalOrder::Backward, "backward",
"backward only traversal."),
- clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+ clEnumValN(mlir::shard::TraversalOrder::ForwardBackward, "forward-backward",
"forward-backward traversal."),
- clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+ clEnumValN(mlir::shard::TraversalOrder::BackwardForward, "backward-forward",
"backward-forward traversal.")
)}]>,
];
let dependentDialects = [
- "mesh::MeshDialect"
+ "shard::ShardDialect"
];
}
-def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
+def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
let summary = "Partition a function into SPMD form.";
let description = [{
This pass fits in right after a pass that annotates the function with
@@ -52,15 +52,15 @@ def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface">
It operates on a fully annotated IR.
A fully annotated IR requires that all ranked tensor operands, results and
- block arguments are annotated with the `mesh.shard` operation.
+ block arguments are annotated with the `shard.shard` operation.
All direct descendant operations in the function must implement the
`ShardingInterface` interface or all their ranked tensor operands and
results must have full replication sharding.
The input IR must have sharding annotations such that each operation
- that implements `ShardingInterface` can handle during spmdization with
- its `spmdize` method.
+ that implements `ShardingInterface` can be handled during partitioning with
+ its `partition` method.
This can be achieved with the `ShardingPropagation` pass.
If the function has multiple terminating blocks,
@@ -70,36 +70,36 @@ def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface">
Example:
```mlir
- mesh.mesh @mesh_1d(shape = 2)
+ shard.grid @grid_1d(shape = 2)
func.func @f(
%arg0: tensor<2xi8>
) -> tensor<2xi8> {
- %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %0 = shard.shard %arg0 to <@grid_1d, [[0]]> : tensor<2xi8>
+ %1 = shard.shard %0 to <@grid_1d, [[0]]> annotate_for_users: tensor<2xi8>
%2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- %3 = mesh.shard %2 to <@mesh_1d, [[0]]> : tensor<2xi8>
- %4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
+ %3 = shard.shard %2 to <@grid_1d, [[0]]> : tensor<2xi8>
+ %4 = shard.shard %3 to <@grid_1d, [[]]> annotate_for_users: tensor<2xi8>
return %4 : tensor<2xi8>
}
```
- Spmdizing the above would result in
+ Partitioning the above would result in
* Performing the element-wise `abs` operation on each device.
* Resharding to full replication with an all-gather.
```mlir
- mesh.mesh @mesh_1d(shape = 2)
+ shard.grid @grid_1d(shape = 2)
func.func @f(%arg0: tensor<1xi8>) -> tensor<2xi8> {
%0 = tosa.abs %arg0 : (tensor<1xi8>) -> tensor<1xi8>
- %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %1 = shard.all_gather %0 on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
return %1 : tensor<2xi8>
}
```
}];
let dependentDialects = [
- "mesh::MeshDialect"
+ "shard::ShardDialect"
];
}
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md b/mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md
similarity index 87%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
rename to mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md
index 6368931cf6e07..cf5ae12b54b2c 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/ReshardingPartitionDoc.md
@@ -1,6 +1,6 @@
-# Resharding Spmdization Examples
+# Resharding Partition Examples
-Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[0, 1]]` on a `2x3` mesh.
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1, 0]]` on a `2x3` grid.
unsharded `2x3` tensor
```
@@ -8,16 +8,16 @@ unsharded `2x3` tensor
21 22 23
```
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
sharding = `[[0, 1]]`
-mesh contents:
+grid contents:
```
-mesh axis 1
+grid axis 1
----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
| 11 | 12 | 13 | |
+----+----+----+ |
| 21 | 22 | 23 | |
@@ -27,9 +27,9 @@ mesh axis 1
Transform into
sharding = `[[1, 0]]`
```
-mesh axis 1
+grid axis 1
----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
| 11 | 13 | 22 | |
+----+----+----+ |
| 12 | 21 | 23 | |
@@ -40,7 +40,7 @@ Swap contents on devices that have the same linear index in the 2 shardings.
--------------------------------------------------------------
-Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` grid.
unsharded `2x3` tensor
```
@@ -48,15 +48,15 @@ unsharded `2x3` tensor
21 22 23
```
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
sharding = `[[0, 1]]`
-mesh contents:
+grid contents:
```
-mesh axis 1
+grid axis 1
----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
| 11 | 12 | 13 | |
+----+----+----+ |
| 21 | 22 | 23 | |
@@ -66,9 +66,9 @@ mesh axis 1
Transform into
sharding = `[[1]]`
```
-mesh axis 1
+grid axis 1
----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
| 11 | 12 | 13 | |
| 21 | 22 | 23 | |
+----+----+----+ |
@@ -77,11 +77,11 @@ mesh axis 1
+----+----+----+ ↓
```
Algorithm:
-All-gather along mesh axis 0.
+All-gather along grid axis 0.
--------------------------------------------------------------
-Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.
+Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` grid.
unsharded `4x6` tensor
```
@@ -89,15 +89,15 @@ unsharded `4x6` tensor
21 22 23 24 25 26
```
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
sharding = `[[], [0, 1]]`
-mesh contents:
+grid contents:
```
-mesh axis 1
+grid axis 1
----------->
-+----+----+----+ mesh axis 0 |
++----+----+----+ grid axis 0 |
| 11 | 12 | 13 | |
| 21 | 22 | 23 | |
+----+----+----+ |
@@ -108,9 +108,9 @@ mesh axis 1
Transform into
sharding = `[[], [0]]`
```
-mesh axis 1
+grid axis 1
----------->
-+----------+----------+ mesh axis 0 |
++----------+----------+ grid axis 0 |
| 11 12 13 | 11 12 13 | |
| 21 22 23 | 21 22 23 | |
+----------+----------+ |
@@ -119,11 +119,11 @@ mesh axis 1
+----------+----------+ ↓
```
Algorithm:
-All-gather along mesh axis 1.
+All-gather along grid axis 1.
--------------------------------------------------------------
-Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.
+Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` grid.
unsharded `4x8` tensor
```
@@ -132,15 +132,15 @@ unsharded `4x8` tensor
31 32 33 34 35 36 37 38
41 42 43 44 45 46 47 48
```
-sharded on a `2x2x2` mesh
+sharded on a `2x2x2` grid
sharding = `[[0], [1, 2]]`
-mesh contents:
+grid contents:
```
-mesh axis 2
+grid axis 2
----------->
-+-------+-------+ mesh axis 1 | mesh axis 0 |
++-------+-------+ grid axis 1 | grid axis 0 |
| 11 12 | 13 14 | | |
| 21 22 | 23 24 | | |
+-------+-------+ | |
@@ -158,9 +158,9 @@ mesh axis 2
Transform into
sharding = `[[0], [2]]`
```
-mesh axis 2
+grid axis 2
----------->
-+-------------+-------------+ mesh axis 1 | mesh axis 0 |
++-------------+-------------+ grid axis 1 | grid axis 0 |
| 11 12 13 14 | 15 16 17 18 | | |
| 21 22 23 24 | 25 26 27 28 | | |
+-------------+-------------+ | |
@@ -177,13 +177,13 @@ mesh axis 2
```
Algorithm:
-Can't be done with just an all-gather along mesh axis 1.
+Can't be done with just an all-gather along grid axis 1.
Can be handled by multiple resharding transformations
`[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`
--------------------------------------------------------------
-Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` grid.
unsharded `6x6` tensor
```
@@ -194,13 +194,13 @@ unsharded `6x6` tensor
51 52 53 54 55 56
61 62 63 64 65 66
```
-sharded on a `2x3` mesh
+sharded on a `2x3` grid
sharding = `[[0], [1]]`
```
-mesh axis 1
+grid axis 1
----------->
-+-------+-------+-------+ mesh axis 0 |
++-------+-------+-------+ grid axis 0 |
| 11 12 | 13 14 | 15 16 | |
| 21 22 | 23 24 | 25 26 | |
| 31 32 | 33 34 | 35 36 | |
@@ -213,9 +213,9 @@ mesh axis 1
transform to
sharding = `[[1], [0]]`
```
-mesh axis 1
+grid axis 1
----------->
-+----------+----------+----------+ mesh axis 0 |
++----------+----------+----------+ grid axis 0 |
| 11 12 13 | 31 32 33 | 51 52 53 | |
| 21 22 23 | 41 42 43 | 61 62 63 | |
+----------+----------+----------+ |
@@ -223,9 +223,9 @@ mesh axis 1
| 24 25 26 | 44 45 46 | 64 65 66 | |
+----------+----------+----------+ ↓
-mesh axis 0
+grid axis 0
----------->
-+----------+----------+ mesh axis 1 |
++----------+----------+ grid axis 1 |
| 11 12 13 | 14 15 16 | |
| 21 22 23 | 24 25 26 | |
+----------+----------+ |
@@ -240,7 +240,7 @@ Algorithm: TODO
--------------------------------------------------------------
-Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` grid.
unsharded 6x6 tensor
```
@@ -251,13 +251,13 @@ unsharded 6x6 tensor
51 52 53 54 55 56
61 62 63 64 65 66
```
-shard on `2x6` mesh
+shard on `2x6` grid
sharding = `[[0], [1]]`
```
-mesh axis 1
+grid axis 1
----------->
-+----+----+----+----+----+----+ mesh axis 0 |
++----+----+----+----+----+----+ grid axis 0 |
| 11 | 12 | 13 ‖ 14 | 15 | 16 | |
| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
| 31 | 32 | 33 ‖ 34 | 35 | 36 | |
@@ -270,9 +270,9 @@ mesh axis 1
transform to
sharding = `[[1], [0]]`
```
-mesh axis 0
+grid axis 0
----------->
-+----------+----------+ mesh axis 1 |
++----------+----------+ grid axis 1 |
| 11 12 13 | 14 15 16 | |
+----------+----------+ |
| 21 22 23 | 24 25 26 | |
@@ -290,9 +290,9 @@ Algorithm: TODO
--------------------------------------------------------------
-Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` mesh.
+Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` grid.
-`M x N` mesh.
+`M x N` grid.
`K x L` tensor `t`.
`d(m, n)` the tensor on device `(m, n)`.
@@ -433,9 +433,9 @@ TODO
--------------------------------------------------------------
-Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` grid.
-Device placement on a `2x3` mesh
+Device placement on a `2x3` grid
```
11 12 13 <- devices
21 22 23
@@ -512,7 +512,7 @@ TODO
--------------------------------------------------------------
-Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.
+Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` grid.
unsharded `6x6` tensor
```
@@ -523,11 +523,11 @@ unsharded `6x6` tensor
51 52 53 54 55 56
61 62 63 64 65 66
```
-sharded on a `3` mesh
+sharded on a `3` grid
sharding = `[[0], []]`
```
-+-------------------+ mesh axis 0 |
++-------------------+ grid axis 0 |
| 11 12 13 14 15 16 | |
| 21 22 23 24 25 26 | |
+-------------------+ |
@@ -541,7 +541,7 @@ sharding = `[[0], []]`
transform to
sharding = `[[], [0]]`
```
-mesh axis 0
+grid axis 0
----------->
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
@@ -554,11 +554,11 @@ mesh axis 0
```
Algorithm:
```mlir
-%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+%1 = all_to_all %0 on @grid grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
```
--------------------------------------------------------------
-Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.
+Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` grid.
unsharded `4x4` tensor
```
@@ -567,13 +567,13 @@ unsharded `4x4` tensor
31 32 33 34
41 42 43 44
```
-sharded on a `2x2x2` mesh
+sharded on a `2x2x2` grid
sharding = `[[0], [1, 2]]`
```
-mesh axis 2
+grid axis 2
----------->
-+----+----+ mesh axis 1 | mesh axis 0 |
++----+----+ grid axis 1 | grid axis 0 |
| 11 | 12 | | |
| 21 | 22 | | |
+----+----+ | |
@@ -591,9 +591,9 @@ mesh axis 2
transform to
sharding = `[[0, 1], [2]]`
```
-mesh axis 2
+grid axis 2
----------->
-+-------+-------+ mesh axis 1 | mesh axis 0 |
++-------+-------+ grid axis 1 | grid axis 0 |
| 11 12 | 13 14 |             |             |
+-------+-------+ | |
| 21 22 | 23 24 | | |
@@ -606,7 +606,7 @@ mesh axis 2
```
Algorithm:
```mlir
-%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
+%1 = all_to_all %0 on @grid grid_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
```
is not enough.
@@ -639,15 +639,15 @@ Basis:
[[0]] -> [[]]
[[0, 1]] -> [[1]]
```
- All-gather along mesh axis 0.
+ All-gather along grid axis 0.
-* Swap mesh axes order when assigned to the same tensor axis.
+* Swap grid axes order when assigned to the same tensor axis.
```
[[0, 1]] -> [[1, 0]]
```
Swap contents on devices with the same linear index.
-* Move mesh axis to different tensor dimension.
+* Move grid axis to different tensor dimension.
```
[[0], []] -> [[], [0]]
```
@@ -661,9 +661,9 @@ Example decomposition of
```
into
```
-[[0], [1]] -> all-gather along mesh axis 1 ->
-[[0], []] -> all-to-all along mesh axis 0 ->
-[[], [0]] -> extract slice along mesh axis 1 ->
+[[0], [1]] -> all-gather along grid axis 1 ->
+[[0], []] -> all-to-all along grid axis 0 ->
+[[], [0]] -> extract slice along grid axis 1 ->
[[1], [0]]
```
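To make the step sequence above concrete, here is a rough sketch in the same abbreviated notation as the earlier snippets, assuming a `6x6` tensor on a `2x3` grid (so the starting local shard is `3x2`); the exact op and attribute spellings are assumptions based on the renamed dialect, not text from this patch:
```mlir
// [[0], [1]] -> [[0], []]: all-gather along grid axis 1 (local 3x2 -> 3x6).
%1 = all_gather %0 on @grid grid_axes = [1] gather_axis = 1 : tensor<3x2xi8> -> tensor<3x6xi8>
// [[0], []] -> [[], [0]]: all-to-all along grid axis 0 (local 3x6 -> 6x3).
%2 = all_to_all %1 on @grid grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<6x3xi8>
// [[], [0]] -> [[1], [0]]: each device keeps only its slice along grid axis 1,
// e.g. a local tensor.extract_slice of 2 rows, yielding the final 2x3 shard.
```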
@@ -675,9 +675,9 @@ Example decomposition of
```
into
```
-[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
-[[3, 2], [1], [0]] -> all-to-all along mesh axis 2 ->
-[[3], [1, 2], [0]] -> all-gather along mesh axis 3 ->
-[[], [1, 2], [0]] -> all-to-all along mesh axis 0 ->
+[[3, 2], [], [0, 1]] -> all-to-all along grid axis 1 ->
+[[3, 2], [1], [0]] -> all-to-all along grid axis 2 ->
+[[3], [1, 2], [0]] -> all-gather along grid axis 3 ->
+[[], [1, 2], [0]] -> all-to-all along grid axis 0 ->
[[0], [1, 2], []]
```
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
similarity index 93%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
index 243dbf081b999..452d4f6b4ed61 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/EndomorphismSimplification.h"
#include "llvm/Support/Casting.h"
@@ -22,7 +22,7 @@ namespace mlir {
class SymbolTableCollection;
-namespace mesh {
+namespace shard {
// If we have an algebraic op like "+" and a summing all-reduce,
// `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
@@ -112,7 +112,7 @@ void populateSimplificationPatterns(
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection);
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
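To make the folding described in this header's comment concrete, a hedged sketch follows; the `shard.all_reduce` spelling and the implicit default `sum` reduction are assumptions based on the renamed op names, not text from this patch:
```mlir
// Before: two all-reduces feeding a local add.
%a = shard.all_reduce %x on @grid grid_axes = [0] : tensor<4xf32> -> tensor<4xf32>
%b = shard.all_reduce %y on @grid grid_axes = [0] : tensor<4xf32> -> tensor<4xf32>
%r = arith.addf %a, %b : tensor<4xf32>

// After the homomorphism simplification: add locally, reduce once.
%s  = arith.addf %x, %y : tensor<4xf32>
%r2 = shard.all_reduce %s on @grid grid_axes = [0] : tensor<4xf32> -> tensor<4xf32>
```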
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
similarity index 65%
rename from mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
index f46c0db846088..57d65e687ea35 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -1,4 +1,4 @@
-//===- Transforms.h - Mesh Transforms ---------------------------*- C++ -*-===//
+//===- Transforms.h - Shard Transforms --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
@@ -20,7 +20,7 @@ class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
class ImplicitLocOpBuilder;
-namespace mesh {
+namespace shard {
void populateProcessMultiIndexOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
@@ -35,20 +35,20 @@ void populateAllOpLoweringPatterns(
void registerAllOpLoweringDialects(DialectRegistry &registry);
TypedValue<IndexType>
-createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
ImplicitLocOpBuilder &builder);
-// Get process linear index along the given mesh axes.
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
- ArrayRef<MeshAxis> meshAxes,
+// Get process linear index along the given grid axes.
+TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
+ ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder);
-// Get process linear index from a multi-index along the given mesh axes .
+// Get process linear index from a multi-index along the given grid axes.
TypedValue<IndexType>
-createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
- ArrayRef<MeshAxis> meshAxes,
+createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder);
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMS_H
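For readers new to the helpers declared above, a small worked example of the quantities they materialize, assuming the usual row-major linearization over the selected grid axes:
```
2x3 grid, process with multi-index (1, 2):
  linear index along grid axes [0, 1]: 1 * 3 + 2 = 5
  linear index along grid axes [1]:    2
  collective process group size over axes [0, 1]: 2 * 3 = 6
```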
diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h
similarity index 88%
rename from mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
rename to mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h
index cfac485b807f2..895e7e5939935 100644
--- a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
+//===- ShardingExtensions.h - -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c6fcf1a0d510b..856170e9308da 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -60,7 +60,6 @@
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -77,6 +76,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -131,7 +131,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
LLVM::LLVMDialect,
math::MathDialect,
memref::MemRefDialect,
- mesh::MeshDialect,
+ shard::ShardDialect,
ml_program::MLProgramDialect,
mpi::MPIDialect,
nvgpu::NVGPUDialect,
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index dd8b292a87344..002ff61fb87dd 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -32,13 +32,13 @@
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
@@ -81,7 +81,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
- mesh::registerMeshPasses();
+ shard::registerShardPasses();
ml_program::registerMLProgramPasses();
quant::registerQuantPasses();
registerSCFPasses();
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index f84375b6b8d6a..785cb8293810c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
-add_subdirectory(MeshToMPI)
+add_subdirectory(ShardToMPI)
add_subdirectory(MPIToLLVM)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
similarity index 65%
rename from mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
rename to mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
index 15560aa61e145..564f36fd20abb 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt
@@ -1,8 +1,8 @@
-add_mlir_conversion_library(MLIRMeshToMPI
- MeshToMPI.cpp
+add_mlir_conversion_library(MLIRShardToMPI
+ ShardToMPI.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI
DEPENDS
MLIRConversionPassIncGen
@@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
MLIRLinalgTransforms
MLIRMemRefDialect
MLIRPass
- MLIRMeshDialect
+ MLIRShardDialect
MLIRMPIDialect
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
similarity index 92%
rename from mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
rename to mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 63b1fdabaf407..8525543760d99 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -1,4 +1,4 @@
-//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements a translation of Mesh communication ops tp MPI ops.
+// This file implements a translation of Shard communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -20,11 +20,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
@@ -35,16 +35,16 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#define DEBUG_TYPE "mesh-to-mpi"
+#define DEBUG_TYPE "shard-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
-#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
-using namespace mesh;
+using namespace shard;
namespace {
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
@@ -188,18 +188,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
// maxSplitSize+1}. Store the offsets in the tensor but set trailing
// elements for smaller split-groups to -1. Computing the max size of the
// split groups needs using collectiveProcessGroupSize (which needs the
- // MeshOp)
+ // GridOp)
Value resOffsets;
if (adaptor.getStaticShardedDimsOffsets().empty()) {
resOffsets = tensor::EmptyOp::create(rewriter, loc,
std::array<int64_t, 2>{0, 0}, i64);
} else {
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto gridOp = getGrid(op, symbolTableCollection);
int64_t maxSplitSize = 0;
for (auto axes : splitAxes) {
int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
assert(splitSize != ShapedType::kDynamic);
maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
}
@@ -218,7 +218,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
int64_t curr = 0;
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
int64_t splitSize =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
@@ -264,20 +264,20 @@ struct ConvertProcessMultiIndexOp
SymbolTableCollection symbolTableCollection;
Location loc = op.getLoc();
- auto meshOp = getMesh(op, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto gridOp = getGrid(op, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return failure();
SmallVector<Value> dims;
llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
- Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp);
+ Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
- // optionally extract subset of mesh axes
+ // optionally extract subset of grid axes
auto axes = adaptor.getAxes();
if (!axes.empty()) {
SmallVector<Value> subIndex;
@@ -338,12 +338,12 @@ struct ConvertNeighborsLinearIndicesOp
Location loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto gridOp = getGrid(op, symbolTableCollection);
auto mIdx = adaptor.getDevice();
auto orgIdx = mIdx[axes[0]];
SmallVector<Value> dims;
llvm::transform(
- meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+ gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return arith::ConstantIndexOp::create(rewriter, loc, i).getResult();
});
Value dimSz = dims[axes[0]];
@@ -394,14 +394,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
if (!sharding) {
return op->emitError()
- << "Expected SharingOp as defining op for sharding"
+ << "Expected ShardingOp as defining op for sharding"
<< " but found " << adaptor.getSharding()[0].getDefiningOp();
}
// Compute the sharded shape by applying the sharding to the input shape.
// If shardedDimsOffsets is not defined in the sharding, the shard shape is
// computed by dividing the dimension size by the number of shards in that
- // dimension (which is given by the size of the mesh axes provided in
+ // dimension (which is given by the size of the grid axes provided in
// split-axes). Odd elements get distributed to trailing shards. If a
// shardedDimsOffsets is provided, the shard shape is computed by
// subtracting the offset of the current shard from the offset of the next
@@ -431,11 +431,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> multiIdx =
getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
- // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+ // Get the GridOp, the grid shape is needed to compute the sharded shape.
SymbolTableCollection symbolTableCollection;
- auto meshOp = getMesh(sharding, symbolTableCollection);
- // For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto gridOp = getGrid(sharding, symbolTableCollection);
+ // For now we only support static grid shapes
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return failure();
auto splitAxes = sharding.getSplitAxes().getAxes();
@@ -455,7 +455,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
tmp);
}
- // With static mesh shape the sizes of the split axes are known.
+ // With static grid shape the sizes of the split axes are known.
// Hence the start/pos for each split axes in shardDimsOffsets can be
// computed statically.
int64_t pos = 0;
@@ -475,10 +475,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
// Create a value from the static position in shardDimsOffsets.
Value posVal = arith::ConstantOp::create(rewriter, loc,
rewriter.getIndexAttr(pos));
- // Get the index of the local shard in the mesh axis.
+ // Get the index of the local shard in the grid axis.
Value idx = multiIdx[axes[0]];
auto numShards =
- collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape());
if (shardedDimsOffs) {
// If sharded dims offsets are provided, use them to compute the
// sharded shape.
@@ -556,13 +556,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SymbolTableCollection symbolTableCollection;
- auto mesh = adaptor.getMesh();
- mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
- if (!meshOp)
- return op->emitError() << "No mesh found for AllReduceOp";
- if (ShapedType::isDynamicShape(meshOp.getShape()))
+ auto grid = adaptor.getGrid();
+ mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
+ if (!gridOp)
+ return op->emitError() << "No grid found for AllReduceOp";
+ if (ShapedType::isDynamicShape(gridOp.getShape()))
return op->emitError()
- << "Dynamic mesh shape not supported in AllReduceOp";
+ << "Dynamic grid shape not supported in AllReduceOp";
ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
Value input = adaptor.getInput();
@@ -592,27 +592,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
linalg::CopyOp::create(iBuilder, input, buffer);
// Get an MPI_Comm_split for the AllReduce operation.
- // The color is the linear index of the process in the mesh along the
- // non-reduced axes. The key is the linear index of the process in the mesh
+ // The color is the linear index of the process in the grid along the
+ // non-reduced axes. The key is the linear index of the process in the grid
// along the reduced axes.
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
iBuilder.getIndexType());
SmallVector<Value> myMultiIndex =
- ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
.getResult();
Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
SmallVector<Value> multiKey(myMultiIndex.size(), zero);
- auto redAxes = adaptor.getMeshAxes();
+ auto redAxes = adaptor.getGridAxes();
for (auto axis : redAxes) {
multiKey[axis] = myMultiIndex[axis];
myMultiIndex[axis] = zero;
}
Value color =
- createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+ createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
- Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+ Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);
// Finally split the communicator
@@ -698,8 +698,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
- auto mesh = adaptor.getMesh();
- auto meshOp = getMesh(op, symbolTableCollection);
+ auto grid = adaptor.getGrid();
+ auto gridOp = getGrid(op, symbolTableCollection);
// subviews need Index values
for (auto &sz : haloSizes) {
if (auto value = dyn_cast<Value>(sz))
@@ -745,10 +745,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
rewriter.getIndexType());
auto myMultiIndex =
- ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh)
+ ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid)
.getResult();
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
@@ -759,7 +759,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// Get the linearized ids of the neighbors (down and up) for the
// given split
auto tmp = rewriter
- .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+ .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex,
splitAxes)
.getResults();
// MPI operates on i32...
@@ -791,7 +791,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
: haloSizes[currHaloDim * 2];
// Check if we need to send and/or receive
- // Processes on the mesh borders have only one neighbor
+ // Processes on the grid borders have only one neighbor
auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
auto hasFrom = arith::CmpIOp::create(
@@ -869,8 +869,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
};
-struct ConvertMeshToMPIPass
- : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+struct ConvertShardToMPIPass
+ : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> {
using Base::Base;
/// Run the dialect converter on the module.
@@ -879,12 +879,12 @@ struct ConvertMeshToMPIPass
RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
- // Define a type converter to convert mesh::ShardingType,
+ // Define a type converter to convert shard::ShardingType,
// mostly for use in return operations.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
- // convert mesh::ShardingType to a tuple of RankedTensorTypes
+ // convert shard::ShardingType to a tuple of RankedTensorTypes
typeConverter.addConversion(
[](ShardingType type,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
@@ -920,10 +920,10 @@ struct ConvertMeshToMPIPass
return results;
});
- // No mesh dialect should left after conversion...
- target.addIllegalDialect<mesh::MeshDialect>();
- // ...except the global MeshOp. MeshShapeOp which will get folded later.
- target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
+    // No shard dialect should be left after conversion...
+ target.addIllegalDialect<shard::ShardDialect>();
+    // ...except the global GridOp and GridShapeOp, which will get folded later.
+ target.addLegalOp<shard::GridOp, shard::GridShapeOp>();
// Allow all the stuff that our patterns will convert to
target.addLegalDialect<
BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
@@ -951,7 +951,7 @@ struct ConvertMeshToMPIPass
// Folding patterns cannot be mixed with conversion patterns -> extra pass.
patterns.clear();
SymbolTableCollection symbolTableCollection;
- mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
+ mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
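A tiny worked example of the shard-shape rule described in ConvertShardShapeOp above (the sizes are hypothetical): without sharded-dims offsets the dimension is divided evenly and the remainder goes to the trailing shards; with offsets the local size is the difference of consecutive offsets.
```
dim size 7, split over a grid axis of size 2, no offsets:  shard sizes = [3, 4]
sharded dims offsets = [0, 3, 7]:                          shard sizes = [3-0, 7-3] = [3, 4]
```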
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index f96bda603baa6..93682a9375dac 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -27,7 +27,7 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
- MLIRMeshDialect
+ MLIRShardDialect
MLIRPass
MLIRShardingInterface
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
index 3478adcb4a128..2906999f8def7 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp
@@ -6,22 +6,22 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
using namespace mlir::arith;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
// Sharding of arith.constant
// RankedTensor constants can be sharded like any other tensor.
// %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
+// %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
// Scalar constants are always replicated and need no sharding annotation.
struct ConstantShardingInterface
@@ -48,8 +48,8 @@ struct ConstantShardingInterface
// Otherwise mirror result sharding if it is a tensor constant.
// Otherwise return replication option.
FailureOr<ShardingOption>
- getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings) const {
+ getShardingOption(Operation *op, ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings) const {
assert(resultShardings.size() == 1 &&
"Expecting exactly one result sharding for arith.constant");
auto resultSharding = resultShardings[0];
@@ -61,17 +61,17 @@ struct ConstantShardingInterface
for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) {
axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end());
}
- return ShardingOption(axesArray, resultSharding.getMeshAttr());
+ return ShardingOption(axesArray, resultSharding.getGridAttr());
}
- return ShardingOption({}, resultSharding.getMeshAttr());
+ return ShardingOption({}, resultSharding.getGridAttr());
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
auto cOp = cast<ConstantOp>(op);
if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) {
if (!value.isSplat() || !resultShardings[0]) {
@@ -80,15 +80,15 @@ struct ConstantShardingInterface
}
auto sharding = resultShardings[0];
auto newType = cast<RankedTensorType>(shardType(
- cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable),
+ cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable),
sharding));
auto newValue = value.resizeSplat(newType);
auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue);
- spmdizationMap.map(op->getResult(0), newOp.getResult());
- spmdizationMap.map(op, newOp.getOperation());
+ partitionMap.map(op->getResult(0), newOp.getResult());
+ partitionMap.map(op, newOp.getOperation());
} else {
// `clone` will populate the mapping of old to new results.
- (void)builder.clone(*op, spmdizationMap);
+ (void)builder.clone(*op, partitionMap);
}
return success();
}
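As an illustration of the splat-resizing path in ConstantShardingInterface above, a hypothetical example; the `shard.grid` declaration syntax, grid name, and shapes are assumptions, while the `shard.sharding` line mirrors the comment in the file:
```mlir
// Hypothetical 4x4 device grid.
shard.grid @grid4x4(shape = 4x4)

// Splat constant, sharded along dimension 0 over grid axis 0 (4 shards).
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding

// After partitioning, each device holds the same splat resized to its local shard.
%cst_local = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
```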
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 3cc52ebc0a8d9..053ee95e92053 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
add_subdirectory(MLProgram)
add_subdirectory(MPI)
add_subdirectory(NVGPU)
diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
index eb6b59bb00f1b..1b18ef2dd04a7 100644
--- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp
@@ -8,7 +8,7 @@
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
-#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
index 47363f48f95cc..87ef51e63f1da 100644
--- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt
@@ -1,7 +1,7 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
InlinerExtension.cpp
- MeshShardingExtensions.cpp
+ ShardingExtensions.cpp
)
add_mlir_extension_library(MLIRFuncInlinerExtension
@@ -17,8 +17,8 @@ add_mlir_extension_library(MLIRFuncInlinerExtension
MLIRFuncDialect
)
-add_mlir_extension_library(MLIRFuncMeshShardingExtensions
- MeshShardingExtensions.cpp
+add_mlir_extension_library(MLIRFuncShardingExtensions
+ ShardingExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions
@@ -38,5 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions
LINK_LIBS PUBLIC
MLIRFuncInlinerExtension
- MLIRFuncMeshShardingExtensions
+ MLIRFuncShardingExtensions
)
diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
similarity index 68%
rename from mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
rename to mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
index da508cc95bfe1..dfd1348c24441 100644
--- a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingExtensions.cpp - ---------------------------------------===//
+//===- ShardingExtensions.cpp - -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/IR/MLIRContext.h"
namespace mlir::func {
@@ -16,7 +16,7 @@ namespace mlir::func {
void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) {
ReturnOp::attachInterface<
- mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
+ shard::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>(
*ctx);
});
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index b6e168e95ee86..7f6ecab2d90f5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -15,7 +15,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
@@ -119,8 +119,8 @@ void mlir::linalg::LinalgDialect::initialize() {
addInterfaces<LinalgInlinerInterface>();
- declarePromisedInterface<mesh::ShardingInterface, GenericOp>();
- declarePromisedInterfaces<mesh::ShardingInterface,
+ declarePromisedInterface<shard::ShardingInterface, GenericOp>();
+ declarePromisedInterfaces<shard::ShardingInterface,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
index 281d9f2204486..ba94ad7906ab7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -10,14 +10,14 @@
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
void mlir::linalg::registerAllDialectInterfaceImplementations(
    DialectRegistry &registry) {
registerBufferizableOpInterfaceExternalModels(registry);
- registerMeshShardingInterfaceExternalModels(registry);
+ registerShardingInterfaceExternalModels(registry);
registerSubsetOpInterfaceExternalModels(registry);
registerTilingInterfaceExternalModels(registry);
registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 69e6fdabf9a58..70f846e5bbd20 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -24,7 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Interchange.cpp
Loops.cpp
TransposeMatmul.cpp
- MeshShardingInterfaceImpl.cpp
+ ShardingInterfaceImpl.cpp
NamedOpConversions.cpp
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
@@ -68,7 +68,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRIR
MLIRMemRefDialect
MLIRMemRefTransforms
- MLIRMeshTransforms
+ MLIRShardTransforms
MLIRLinalgDialect
MLIRLinalgUtils
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
similarity index 66%
rename from mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 8208a3123050e..0c0fcff64487f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
@@ -38,13 +38,13 @@
namespace mlir::linalg {
-using MeshAxis = mesh::MeshAxis;
-using ReductionKind = mesh::ReductionKind;
-using MeshSharding = mesh::MeshSharding;
-using ShardingArray = mesh::ShardingArray;
-using MeshOp = mesh::MeshOp;
+using GridAxis = shard::GridAxis;
+using ReductionKind = shard::ReductionKind;
+using Sharding = shard::Sharding;
+using ShardingArray = shard::ShardingArray;
+using GridOp = shard::GridOp;
-// Returns the corresponding mesh reduction kind for the given arith op.
+// Returns the corresponding grid reduction kind for the given arith op.
static ReductionKind getReductionKind(Operation *op) {
return llvm::TypeSwitch<Operation *, ReductionKind>(op)
// Floating-point operations.
@@ -99,18 +99,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
return getReductionKind(reductionOp.value());
}
-static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
SymbolTableCollection &symbolTable) {
- for (const MeshSharding &sharding : operandShardings) {
+ for (const Sharding &sharding : operandShardings) {
if (sharding) {
- return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
+ return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
}
}
- for (const MeshSharding &sharding : resultShardings) {
+ for (const Sharding &sharding : resultShardings) {
if (sharding) {
- return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
+ return shard::getGrid(op, sharding.getGridAttr(), symbolTable);
}
}
@@ -119,29 +119,29 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
}
// Choose the operand based on the current process index along the reduction
-// mesh axes.
+// grid axes.
// We need to use the initial value only once to avoid including it in the
// reduction multiple times.
// In each process group only the leading process with linear index 0 would use
// the original operand.
// The other processes would use the reduction operation neutral tensor.
static Value createDestinationPassingStyleInitOperand(
- LinalgOp op, int operandNumber, Value spmdizedOperand,
- ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
+ LinalgOp op, int operandNumber, Value partitiondOperand,
+ ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
ImplicitLocOpBuilder &builder) {
- Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
- meshOp.getSymName(), reductionMeshAxes, builder);
+ Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
+ gridOp.getSymName(), reductionGridAxes, builder);
Value zero = builder.create<arith::ConstantIndexOp>(0);
Value isLeadProcess = builder.create<arith::CmpIOp>(
builder.getI1Type(), arith::CmpIPredicate::eq,
processLinearIndexInReductionGroup, zero);
- scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+ scf::IfOp ifOp = builder.create<scf::IfOp>(partitiondOperand.getType(),
isLeadProcess, true, true);
// Then block.
{
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
- builder.create<scf::YieldOp>(spmdizedOperand);
+ builder.create<scf::YieldOp>(partitiondOperand);
}
// Else block.
@@ -149,7 +149,7 @@ static Value createDestinationPassingStyleInitOperand(
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
SmallVector<OpFoldResult> shape =
- tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+ tensor::getMixedSizes(builder, builder.getLoc(), partitiondOperand);
SmallVector<Operation *> combinerOps;
matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
@@ -169,73 +169,72 @@ static Value createDestinationPassingStyleInitOperand(
return ifOp.getResult(0);
}
-// Create the DPS init operands for the spmdized Linalg op.
-// Return all the new spmdized operands.
+// Create the DPS init operands for the partitioned Linalg op.
+// Return all the new partitioned operands.
static SmallVector<Value> createDestinationPassingStyleInitOperands(
- LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+ LinalgOp op, GridOp gridOp, ArrayRef<Value> partitiondOperands,
+ ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap,
ImplicitLocOpBuilder &builder) {
// TODO: add support for multiple destination passing style initial value
// operands.
assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
- SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+ SmallVector<Value> newOperands = llvm::to_vector(partitiondOperands);
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
- Value spmdizedInitOperand =
- spmdizationMap.lookup(op->getOperands()[operandIdx]);
+ Value partitiondInitOperand =
+ partitionMap.lookup(op->getOperands()[operandIdx]);
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
- op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+ op, 0, partitiondInitOperand, reductionGridAxes, gridOp, builder);
return newOperands;
}
static void createAllReduceForResultsWithoutPartialShardings(
- LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
+ LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes,
+ ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
ImplicitLocOpBuilder &builder) {
ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
for (auto [unshardedLinalgOpResult, resultSharding] :
llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
- Value spmdizedLinalgOpResult =
- spmdizationMap.lookup(unshardedLinalgOpResult);
- Value reducedValue = builder.create<mesh::AllReduceOp>(
- spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes,
+ Value partitiondLinalgOpResult =
+ partitionMap.lookup(unshardedLinalgOpResult);
+ Value reducedValue = builder.create<shard::AllReduceOp>(
+ partitiondLinalgOpResult, resultSharding.getGrid(), opReductionGridAxes,
reductionKind);
- spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+ partitionMap.map(unshardedLinalgOpResult, reducedValue);
}
}
-static void spmdizeLinalgOpWithShardedReduction(
- LinalgOp op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+static void partitionLinalgOpWithShardedReduction(
+ LinalgOp op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
- IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators,
+ IRMapping &partitionMap, SymbolTableCollection &symbolTable,
ImplicitLocOpBuilder &builder) {
- MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
- SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
- loopIteratorTypes, meshAxisAssignmentForLoopIterators);
- SmallVector<Value> spmdizedLinalgOpOperands =
- createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
- reductionMeshAxes,
- spmdizationMap, builder);
- // We must not change the operand mappings of the original spmdizationMap as
- // they are the mappings for the whole spmdization blob and may be used by
+ GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable);
+ SmallVector<GridAxis> reductionGridAxes = shard::getReductionGridAxes(
+ loopIteratorTypes, gridAxisAssignmentForLoopIterators);
+ SmallVector<Value> partitiondLinalgOpOperands =
+ createDestinationPassingStyleInitOperands(op, grid, partitiondOperands,
+ reductionGridAxes, partitionMap,
+ builder);
+ // We must not change the operand mappings of the original partitionMap as
+ // they are the mappings for the whole partition blob and may be used by
// others.
- IRMapping internalSpmdizationMap;
- for (auto [unshardedOperand, spmdizedOperand] :
- llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
- internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
+ IRMapping internalPartitionMap;
+ for (auto [unshardedOperand, partitiondOperand] :
+ llvm::zip_equal(op->getOperands(), partitiondLinalgOpOperands)) {
+ internalPartitionMap.map(unshardedOperand, partitiondOperand);
}
- spmdizeTriviallyShardableOperation(
- *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
- internalSpmdizationMap, symbolTable, builder);
+ partitionTriviallyShardableOperation(
+ *op, partitiondLinalgOpOperands, operandShardings, resultShardings,
+ internalPartitionMap, symbolTable, builder);
for (Value result : op->getResults()) {
- spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
+ partitionMap.map(result, internalPartitionMap.lookup(result));
}
// Handle partial shardings.
createAllReduceForResultsWithoutPartialShardings(
- op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+ op, reductionGridAxes, resultShardings, partitionMap, builder);
}
namespace {
@@ -245,7 +244,7 @@ namespace {
// permutations.
template <typename Op>
struct StructuredOpShardingInterface
- : public mesh::ShardingInterface::ExternalModel<
+ : public shard::ShardingInterface::ExternalModel<
StructuredOpShardingInterface<Op>, Op> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
@@ -274,16 +273,16 @@ struct StructuredOpShardingInterface
[](unsigned count, utils::IteratorType iter) {
return count + (iter == utils::IteratorType::reduction);
});
- mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+ shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
@@ -299,20 +298,20 @@ struct StructuredOpShardingInterface
SmallVector<utils::IteratorType> loopIteratorTypes =
linalgOp.getIteratorTypesArray();
- ShardingArray meshAxisAssignmentForLoopIterators =
- getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+ ShardingArray gridAxisAssignmentForLoopIterators =
+ getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings,
loopIteratorTypes, indexingMaps);
- if (mesh::isAtLeastOneReductionIteratorSharded(
- loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (shard::isAtLeastOneReductionIteratorSharded(
+ loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
- spmdizeLinalgOpWithShardedReduction(
- linalgOp, spmdizedOperands, operandShardings, resultShardings,
- loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+ partitionLinalgOpWithShardedReduction(
+ linalgOp, partitiondOperands, operandShardings, resultShardings,
+ loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap,
symbolTable, implicitLocBuilder);
} else {
- spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
- operandShardings, resultShardings,
- spmdizationMap, symbolTable, builder);
+ partitionTriviallyShardableOperation(*op, partitiondOperands,
+ operandShardings, resultShardings,
+ partitionMap, symbolTable, builder);
}
return success();
@@ -332,7 +331,7 @@ static void registerAll(MLIRContext *ctx) {
(registerOne<OpTypes>(ctx), ...);
}
-void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+void registerShardingInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
DialectRegistry registry;
registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Shard/CMakeLists.txt
similarity index 100%
rename from mlir/lib/Dialect/Mesh/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/CMakeLists.txt
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt
similarity index 59%
rename from mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/IR/CMakeLists.txt
index 3fea4d67430e0..70c6049884e12 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt
@@ -1,11 +1,11 @@
-add_mlir_dialect_library(MLIRMeshDialect
- MeshOps.cpp
+add_mlir_dialect_library(MLIRShardDialect
+ ShardOps.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
DEPENDS
- MLIRMeshIncGen
+ MLIRShardIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
similarity index 76%
rename from mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
rename to mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index 28608cb0dd96c..df2fcf4c2f4c8 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -1,4 +1,4 @@
-//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -37,13 +37,13 @@
#include <optional>
#include <utility>
-#define DEBUG_TYPE "mesh-ops"
+#define DEBUG_TYPE "shard-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
-#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc"
namespace {
@@ -74,11 +74,10 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
return lhs.value() * rhs.value();
}
-SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
- const Location &loc,
- llvm::ArrayRef<int64_t> statics,
- ValueRange dynamics,
- Type type) {
+SmallVector<Value>
+mlir::shard::getMixedAsValues(OpBuilder b, const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics, Type type) {
SmallVector<Value> values;
auto dyn = dynamics.begin();
Type i64 = b.getI64Type();
@@ -102,7 +101,7 @@ SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
//===----------------------------------------------------------------------===//
namespace {
-struct MeshInlinerInterface : public DialectInlinerInterface {
+struct ShardInlinerinterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
// Currently no restrictions are encoded for inlining.
bool isLegalToInline(Operation *, Operation *, bool) const final {
@@ -118,44 +117,45 @@ struct MeshInlinerInterface : public DialectInlinerInterface {
} // namespace
//===----------------------------------------------------------------------===//
-// Mesh dialect
+// Shard dialect
//===----------------------------------------------------------------------===//
-void MeshDialect::initialize() {
+void ShardDialect::initialize() {
addOperations<
#define GET_OP_LIST
-#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
>();
addTypes<
#define GET_TYPEDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
>();
- addInterface<MeshInlinerInterface>();
+ addInterface<ShardInlinerinterface>();
}
-Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
- Type type, Location loc) {
+Operation *ShardDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
return arith::ConstantOp::materialize(builder, value, type, loc);
}
//===----------------------------------------------------------------------===//
-// Mesh utilities
+// Shard utilities
//===----------------------------------------------------------------------===//
-static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
- FlatSymbolRefAttr meshSymbol,
+static FailureOr<GridOp> getGridAndVerify(Operation *op,
+ FlatSymbolRefAttr gridSymbol,
SymbolTableCollection &symbolTable) {
- mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
- if (!mesh) {
- return op->emitError() << "Undefined required mesh symbol \""
- << meshSymbol.getValue() << "\".";
+ shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable);
+ if (!grid) {
+ return op->emitError() << "Undefined required grid symbol \""
+ << gridSymbol.getValue() << "\".";
}
- return mesh;
+ return grid;
}
template <typename It>
@@ -175,20 +175,20 @@ bool isUnique(It begin, It end) {
return true;
}
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- MeshOp mesh) {
- SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes,
+ GridOp grid) {
+ SmallVector<GridAxis> sorted = llvm::to_vector(axes);
llvm::sort(sorted);
if (!isUnique(sorted.begin(), sorted.end())) {
- return emitError(loc) << "Mesh axes contains duplicate elements.";
+ return emitError(loc) << "Grid axes contains duplicate elements.";
}
- MeshAxis rank = mesh.getRank();
+ GridAxis rank = grid.getRank();
for (auto axis : axes) {
if (axis >= rank || axis < 0) {
return emitError(loc)
- << "0-based mesh axis index " << axis
- << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+ << "0-based grid axis index " << axis
+ << " is out of bounds. The referenced grid \"" << grid.getSymName()
<< "\" is of rank " << rank << ".";
}
}
@@ -197,22 +197,22 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
}
template <typename Op>
-static FailureOr<MeshOp>
-getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
- auto mesh =
- ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+static FailureOr<GridOp>
+getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
+ auto grid =
+ ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
+ if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) {
return failure();
}
- return mesh;
+ return grid;
}
-template <typename InShape, typename MeshShape, typename SplitAxes,
+template <typename InShape, typename GridShape, typename SplitAxes,
typename OutShape>
-static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+static void shardShape(const InShape &inShape, const GridShape &gridShape,
const SplitAxes &splitAxes, OutShape &outShape,
ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
@@ -226,7 +226,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
llvm::adl_begin(outShape));
if (!shardedDimsOffsets.empty()) {
- auto isDynShape = ShapedType::isDynamicShape(meshShape);
+ auto isDynShape = ShapedType::isDynamicShape(gridShape);
uint64_t pos = 1;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
if (!innerSplitAxes.empty()) {
@@ -238,7 +238,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
// non-uniform offs in shardedDimsOffsets.
uint64_t numShards = 0;
for (auto i : innerSplitAxes.asArrayRef()) {
- numShards += meshShape[i];
+ numShards += gridShape[i];
}
for (size_t i = 1; i < numShards; ++i) {
if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] !=
@@ -256,7 +256,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
outShape[tensorAxis] = shardDimension(
inShape[tensorAxis],
- collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
+ collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape));
}
if (!haloSizes.empty()) {
@@ -279,25 +279,25 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
}
}
-ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
- MeshSharding sharding) {
+ShapedType shard::shardShapedType(ShapedType shape, GridOp grid,
+ Sharding sharding) {
using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
SmallVector<Dim> resShapeArr(shape.getShape().size());
- shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
+ shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(),
resShapeArr, sharding.getStaticShardedDimsOffsets(),
sharding.getStaticHaloSizes());
return shape.clone(resShapeArr);
}
-Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
+Type shard::shardType(Type type, GridOp grid, Sharding sharding) {
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
if (rankedTensorType && !rankedTensorType.getShape().empty()) {
- return shardShapedType(rankedTensorType, mesh, sharding);
+ return shardShapedType(rankedTensorType, grid, sharding);
}
return type;
}
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding,
Value &operandValue,
Operation *operandOp,
OpBuilder &builder,
@@ -336,9 +336,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpResult result,
- OpBuilder &builder) {
+void mlir::shard::maybeInsertTargetShardingAnnotation(Sharding sharding,
+ OpResult result,
+ OpBuilder &builder) {
ShardOp newShardOp;
SmallVector<std::pair<Value, Operation *>> uses;
for (auto &use : result.getUses()) {
@@ -350,9 +350,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
}
}
-void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder) {
+void mlir::shard::maybeInsertSourceShardingAnnotation(Sharding sharding,
+ OpOperand &operand,
+ OpBuilder &builder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Value operandValue = operand.get();
Operation *operandSrcOp = operandValue.getDefiningOp();
@@ -404,18 +404,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
}
//===----------------------------------------------------------------------===//
-// mesh.mesh op
+// shard.grid op
//===----------------------------------------------------------------------===//
-LogicalResult MeshOp::verify() {
+LogicalResult GridOp::verify() {
int64_t rank = getRank();
if (rank <= 0)
- return emitOpError("rank of mesh is expected to be a positive integer");
+ return emitOpError("rank of grid is expected to be a positive integer");
for (int64_t dimSize : getShape()) {
if (dimSize < 0 && ShapedType::isStatic(dimSize))
- return emitOpError("dimension size of a mesh is expected to be "
+ return emitOpError("dimension size of a grid is expected to be "
"non-negative or dynamic");
}
@@ -423,21 +423,21 @@ LogicalResult MeshOp::verify() {
}
//===----------------------------------------------------------------------===//
-// mesh.mesh_shape op
+// shard.grid_shape op
//===----------------------------------------------------------------------===//
LogicalResult
-MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
return failure();
}
size_t expectedResultsCount =
- getAxes().empty() ? mesh->getRank() : getAxes().size();
+ getAxes().empty() ? grid->getRank() : getAxes().size();
if (getResult().size() != expectedResultsCount) {
return emitError() << "Unexpected number of results " << getResult().size()
<< ". Expected " << expectedResultsCount << ".";
@@ -446,53 +446,53 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- MeshOp mesh) {
- build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ GridOp grid) {
+ build(odsBuilder, odsState, grid, SmallVector<GridAxis>());
}
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- MeshOp mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ GridOp grid, ArrayRef<GridAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+ SmallVector<Type>(axes.empty() ? grid.getRank() : axes.size(),
odsBuilder.getIndexType()),
- mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
+ grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes));
}
-void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef mesh, ArrayRef<MeshAxis> axes) {
+void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef grid, ArrayRef<GridAxis> axes) {
assert(!axes.empty());
build(odsBuilder, odsState,
- SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
- MeshAxesAttr::get(odsBuilder.getContext(), axes));
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+ GridAxesAttr::get(odsBuilder.getContext(), axes));
}
-void MeshShapeOp::getAsmResultNames(
+void GridShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
- setNameFn(getResults()[0], "mesh_shape");
+ setNameFn(getResults()[0], "grid_shape");
}
//===----------------------------------------------------------------------===//
-// mesh.sharding
+// shard.sharding
//===----------------------------------------------------------------------===//
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr mesh,
- ArrayRef<MeshAxesAttr> split_axes,
+ FlatSymbolRefAttr grid,
+ ArrayRef<GridAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
return build(
- b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {});
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes,
+ llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes,
ArrayRef<int64_t> static_halos,
ArrayRef<int64_t> static_offsets) {
- return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh),
- MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid),
+ GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {},
::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets),
{});
@@ -500,7 +500,7 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
void ShardingOp::build(
::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
+ FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes,
::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes,
::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) {
mlir::SmallVector<int64_t> staticHalos, staticDims;
@@ -508,16 +508,16 @@ void ShardingOp::build(
dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos);
dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims);
return build(
- b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
+ b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes),
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos,
::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims);
}
void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
- mlir::mesh::MeshSharding from) {
+ mlir::shard::Sharding from) {
- build(b, odsState, ShardingType::get(b.getContext()), from.getMeshAttr(),
- MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
+ build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(),
+ GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()),
from.getStaticShardedDimsOffsets().empty()
? DenseI64ArrayAttr()
: b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()),
@@ -529,21 +529,21 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState,
}
LogicalResult ShardingOp::verify() {
- llvm::SmallSet<MeshAxis, 4> visitedAxes;
+ llvm::SmallSet<GridAxis, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
- for (MeshAxis axis : axesArray) {
+ auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult {
+ for (GridAxis axis : axesArray) {
if (axis < 0)
- return emitError() << "mesh axis is expected to be non-negative";
+ return emitError() << "grid axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
- return emitError() << "mesh axis duplicated";
+ return emitError() << "grid axis duplicated";
}
return success();
};
for (auto subAxes : getSplitAxes().getAxes()) {
- ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
- if (failed(checkMeshAxis(subAxesArray)))
+ ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef();
+ if (failed(checkGridAxis(subAxesArray)))
return failure();
}
@@ -572,26 +572,26 @@ void ShardingOp::getAsmResultNames(
}
LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
+ if (mlir::ShapedType::isDynamicShape(grid->getShape()) &&
getStaticShardedDimsOffsets().size() > 0) {
return emitError() << "sharded dims offsets are not allowed for "
- "devices meshes with dynamic shape.";
+ "device grids with dynamic shape.";
}
auto shardedDimsOffsets = getStaticShardedDimsOffsets();
if (!shardedDimsOffsets.empty()) {
- auto meshShape = mesh.value().getShape();
- assert(ShapedType::isStaticShape(meshShape));
+ auto gridShape = grid.value().getShape();
+ assert(ShapedType::isStaticShape(gridShape));
uint64_t pos = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
if (!innerSplitAxes.empty()) {
int64_t numShards = 0, off = 0;
for (auto i : innerSplitAxes.asArrayRef()) {
- numShards += meshShape[i];
+ numShards += gridShape[i];
}
for (int64_t i = 0; i <= numShards; ++i) {
if (shardedDimsOffsets.size() <= pos + i) {
@@ -684,11 +684,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// MeshSharding
+// Sharding
//===----------------------------------------------------------------------===//
-bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
- if (getMesh() != rhs.getMesh()) {
+bool Sharding::equalSplitAxes(const Sharding &rhs) const {
+ if (getGrid() != rhs.getGrid()) {
return false;
}
@@ -701,16 +701,16 @@ bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const {
}
return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize),
- std::mem_fn(&MeshAxesAttr::empty)) &&
+ std::mem_fn(&GridAxesAttr::empty)) &&
llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize),
- std::mem_fn(&MeshAxesAttr::empty));
+ std::mem_fn(&GridAxesAttr::empty));
}
-bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloAndShardSizes(const Sharding &rhs) const {
return equalShardSizes(rhs) && equalHaloSizes(rhs);
}
-bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
+bool Sharding::equalShardSizes(const Sharding &rhs) const {
if (rhs.getStaticShardedDimsOffsets().size() !=
getStaticShardedDimsOffsets().size() ||
!llvm::equal(getStaticShardedDimsOffsets(),
@@ -726,7 +726,7 @@ bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const {
return true;
}
-bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
+bool Sharding::equalHaloSizes(const Sharding &rhs) const {
if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
!llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) {
return false;
@@ -738,45 +738,43 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const {
return true;
}
-bool MeshSharding::operator==(Value rhs) const {
+bool Sharding::operator==(Value rhs) const {
return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
}
-bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); }
+bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); }
-bool MeshSharding::operator==(const MeshSharding &rhs) const {
+bool Sharding::operator==(const Sharding &rhs) const {
return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs);
}
-bool MeshSharding::operator!=(const MeshSharding &rhs) const {
- return !(*this == rhs);
-}
+bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); }
-MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
+Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {}
-MeshSharding::MeshSharding(Value rhs) {
+Sharding::Sharding(Value rhs) {
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
assert(shardingOp && "expected sharding op");
auto splitAxes = shardingOp.getSplitAxes().getAxes();
// If splitAxes are empty, use "empty" constructor.
if (splitAxes.empty()) {
- *this = MeshSharding(shardingOp.getMeshAttr());
+ *this = Sharding(shardingOp.getGridAttr());
return;
}
*this =
- get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
+ get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(),
shardingOp.getStaticShardedDimsOffsets(),
SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets()));
}
-MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
- ArrayRef<MeshAxesAttr> split_axes_,
- ArrayRef<int64_t> static_halo_sizes_,
- ArrayRef<int64_t> static_sharded_dims_offsets_,
- ArrayRef<Value> dynamic_halo_sizes_,
- ArrayRef<Value> dynamic_sharded_dims_offsets_) {
- MeshSharding res(mesh_);
+Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_,
+ ArrayRef<GridAxesAttr> split_axes_,
+ ArrayRef<int64_t> static_halo_sizes_,
+ ArrayRef<int64_t> static_sharded_dims_offsets_,
+ ArrayRef<Value> dynamic_halo_sizes_,
+ ArrayRef<Value> dynamic_sharded_dims_offsets_) {
+ Sharding res(grid_);
if (split_axes_.empty()) {
return res;
}
@@ -784,7 +782,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
res.split_axes.resize(split_axes_.size());
for (auto [i, axis] : llvm::enumerate(split_axes_)) {
res.split_axes[i] =
- MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+ GridAxesAttr::get(grid_.getContext(), axis.asArrayRef());
}
auto clone = [](const auto src, auto &dst) {
@@ -801,7 +799,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
}
//===----------------------------------------------------------------------===//
-// mesh.shard_shape
+// shard.shard_shape
//===----------------------------------------------------------------------===//
void ShardShapeOp::getAsmResultNames(
@@ -820,7 +818,7 @@ void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
}
//===----------------------------------------------------------------------===//
-// mesh.shard op
+// shard.shard op
//===----------------------------------------------------------------------===//
void ShardOp::getAsmResultNames(
@@ -850,10 +848,10 @@ class FoldDuplicateShardOp final : public OpRewritePattern<ShardOp> {
if (!otherOp || !otherOp->isBeforeInBlock(op)) {
return failure();
}
- // Create a MeshSharding object for the current and the other ShardOp
+ // Create a Sharding object for the current and the other ShardOp
// If the two are equal replace current op with the other op.
- MeshSharding currentSharding(op.getSharding());
- MeshSharding otherSharding(otherOp.getSharding());
+ Sharding currentSharding(op.getSharding());
+ Sharding otherSharding(otherOp.getSharding());
if (currentSharding == otherSharding) {
b.replaceAllUsesWith(op.getResult(), otherOp.getResult());
b.eraseOp(op.getOperation());
@@ -876,21 +874,21 @@ void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// mesh.process_multi_index op
+// shard.process_multi_index op
//===----------------------------------------------------------------------===//
LogicalResult
ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
- if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) {
return failure();
}
size_t expectedResultsCount =
- getAxes().empty() ? mesh->getRank() : getAxes().size();
+ getAxes().empty() ? grid->getRank() : getAxes().size();
if (getResult().size() != expectedResultsCount) {
return emitError() << "Unexpected number of results " << getResult().size()
<< ". Expected " << expectedResultsCount << ".";
@@ -900,17 +898,17 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- MeshOp mesh) {
+ GridOp grid) {
build(odsBuilder, odsState,
- SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(), ArrayRef<MeshAxis>());
+ SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()),
+ grid.getSymName(), ArrayRef<GridAxis>());
}
void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- StringRef mesh, ArrayRef<MeshAxis> axes) {
+ StringRef grid, ArrayRef<GridAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
- MeshAxesAttr::get(odsBuilder.getContext(), axes));
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid,
+ GridAxesAttr::get(odsBuilder.getContext(), axes));
}
void ProcessMultiIndexOp::getAsmResultNames(
@@ -919,21 +917,21 @@ void ProcessMultiIndexOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.process_linear_index op
+// shard.process_linear_index op
//===----------------------------------------------------------------------===//
LogicalResult
ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
return success();
}
void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
- OperationState &odsState, MeshOp mesh) {
- build(odsBuilder, odsState, mesh.getSymName());
+ OperationState &odsState, GridOp grid) {
+ build(odsBuilder, odsState, grid.getSymName());
}
void ProcessLinearIndexOp::getAsmResultNames(
@@ -942,13 +940,13 @@ void ProcessLinearIndexOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.neighbors_linear_indices op
+// shard.neighbors_linear_indices op
//===----------------------------------------------------------------------===//
LogicalResult
NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
return success();
@@ -967,12 +965,12 @@ void NeighborsLinearIndicesOp::getAsmResultNames(
namespace {
template <typename Op>
-struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
+struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;
LogicalResult matchAndRewrite(Op op,
PatternRewriter &rewriter) const override {
- auto meshAxes = op.getMeshAxes();
- if (!meshAxes.empty()) {
+ auto gridAxes = op.getGridAxes();
+ if (!gridAxes.empty()) {
return failure();
}
if (op.getInput().getType() != op.getResult().getType()) {
@@ -990,24 +988,24 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
ArrayRef<int64_t> device,
Operation::operand_range deviceDynamic,
- ArrayRef<MeshAxis> meshAxes,
- ArrayRef<int64_t> meshShape) {
- if (device.size() != meshAxes.size()) {
+ ArrayRef<GridAxis> gridAxes,
+ ArrayRef<int64_t> gridShape) {
+ if (device.size() != gridAxes.size()) {
return emitError(loc) << "In-group device \"" << deviceName
<< "\" has unexpected multi-index size "
- << device.size() << ". Expected " << meshAxes.size()
+ << device.size() << ". Expected " << gridAxes.size()
<< ".";
}
for (size_t i = 0; i < device.size(); ++i) {
if (ShapedType::isStatic(device[i]) &&
- ShapedType::isStatic(meshShape[meshAxes[i]]) &&
- meshShape[meshAxes[i]] <= device[i]) {
+ ShapedType::isStatic(gridShape[gridAxes[i]]) &&
+ gridShape[gridAxes[i]] <= device[i]) {
return emitError(loc)
<< "Out of bounds coordinate " << i << " for in-group device \""
<< deviceName << "\"."
<< " Got " << device[i] << ", but expected value in the range [0, "
- << (meshShape[meshAxes[i]] - 1) << "].";
+ << (gridShape[gridAxes[i]] - 1) << "].";
}
}
return success();
@@ -1043,7 +1041,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
- ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
auto resultRank = cast<ShapedType>(result.getType()).getRank();
if (gatherAxis < 0 || gatherAxis >= resultRank) {
return emitError(result.getLoc())
@@ -1054,7 +1052,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
auto deviceGroupSize =
- DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -1070,7 +1068,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
static LogicalResult verifyAllToAllOperandAndResultShape(
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
- ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1088,7 +1086,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
}
auto deviceGroupSize =
- DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
DimensionSize expectedResultConcatDimSize =
@@ -1115,7 +1113,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
Value operand, Value result, int64_t tensorAxis,
- ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
+ ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) {
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@@ -1129,7 +1127,7 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
}
auto deviceGroupSize =
- DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape));
auto operandScatterDimSize =
DimensionSize(operandType.getDimSize(tensorAxis));
if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
@@ -1151,8 +1149,8 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape(
return success();
}
-static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
- ArrayRef<MeshAxis> meshAxes,
+static RankedTensorType sliceResultType(Type operandType, GridOp grid,
+ ArrayRef<GridAxis> gridAxes,
int64_t sliceAxis) {
RankedTensorType operandRankedTensorType =
cast<RankedTensorType>(operandType);
@@ -1163,29 +1161,29 @@ static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
resultShape[sliceAxis] =
operandSliceAxisSize /
- DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+ DimensionSize(collectiveProcessGroupSize(gridAxes, grid));
return operandRankedTensorType.clone(resultShape);
}
//===----------------------------------------------------------------------===//
-// mesh.all_gather op
+// shard.all_gather op
//===----------------------------------------------------------------------===//
LogicalResult
AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getOperand(), getResult(),
- gatherAxis, getMeshAxes(),
- mesh.value().getShape());
+ gatherAxis, getGridAxes(),
+ grid.value().getShape());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context);
}
void AllGatherOp::getAsmResultNames(
@@ -1194,23 +1192,23 @@ void AllGatherOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.all_reduce op
+// shard.all_reduce op
//===----------------------------------------------------------------------===//
LogicalResult
AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- return getMeshAndVerifyAxes(*this, symbolTable);
+ return getGridAndVerifyAxes(*this, symbolTable);
}
void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context);
}
void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Value input, StringRef mesh,
- ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
- build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+ Value input, StringRef grid,
+ ArrayRef<GridAxis> gridAxes, ReductionKind reduction) {
+ build(odsBuilder, odsState, input.getType(), grid, gridAxes, input,
reduction);
}
@@ -1220,36 +1218,36 @@ void AllReduceOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.all_slice op
+// shard.all_slice op
//===----------------------------------------------------------------------===//
LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
return verifyScatterOrSliceOperandAndResultShape(
- getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
- mesh.value().getShape());
+ getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(),
+ grid.value().getShape());
}
void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context);
}
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+ Value input, GridOp grid, ArrayRef<GridAxis> gridAxes,
int64_t sliceAxis) {
- Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
- build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+ Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis);
+ build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes,
sliceAxis);
}
void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Type resultType, Value input, StringRef mesh,
- ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
- build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+ Type resultType, Value input, StringRef grid,
+ ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) {
+ build(odsBuilder, odsState, resultType, grid, gridAxes, input,
APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
@@ -1259,23 +1257,23 @@ void AllSliceOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.all_to_all op
+// shard.all_to_all op
//===----------------------------------------------------------------------===//
LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
return verifyAllToAllOperandAndResultShape(
getOperand(), getResult(), getSplitAxis().getSExtValue(),
- getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
+ getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape());
}
void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context);
}
void AllToAllOp::getAsmResultNames(
@@ -1284,18 +1282,18 @@ void AllToAllOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.broadcast op
+// shard.broadcast op
//===----------------------------------------------------------------------===//
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
@@ -1304,7 +1302,7 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context);
}
void BroadcastOp::getAsmResultNames(
@@ -1313,29 +1311,29 @@ void BroadcastOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.gather op
+// shard.gather op
//===----------------------------------------------------------------------===//
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
- getMeshAxes(),
- mesh.value().getShape());
+ getGridAxes(),
+ grid.value().getShape());
}
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context);
}
void GatherOp::getAsmResultNames(
@@ -1344,18 +1342,18 @@ void GatherOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.recv op
+// shard.recv op
//===----------------------------------------------------------------------===//
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (getSource() &&
failed(verifyInGroupDevice(getLoc(), getSourceAttrName(),
getSource().value(), getSourceDynamic(),
- getMeshAxes(), mesh.value().getShape()))) {
+ getGridAxes(), grid.value().getShape()))) {
return failure();
}
return success();
@@ -1363,7 +1361,7 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context);
}
void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1371,17 +1369,17 @@ void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
}
//===----------------------------------------------------------------------===//
-// mesh.reduce op
+// shard.reduce op
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
@@ -1390,7 +1388,7 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context);
}
void ReduceOp::getAsmResultNames(
@@ -1399,24 +1397,24 @@ void ReduceOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.reduce_scatter op
+// shard.reduce_scatter op
//===----------------------------------------------------------------------===//
LogicalResult
ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
return verifyScatterOrSliceOperandAndResultShape(
- getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
- mesh.value().getShape());
+ getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(),
+ grid.value().getShape());
}
void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context);
}
void ReduceScatterOp::getAsmResultNames(
@@ -1425,29 +1423,29 @@ void ReduceScatterOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.scatter op
+// shard.scatter op
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
- getRootDynamic(), getMeshAxes(),
- mesh.value().getShape()))) {
+ getRootDynamic(), getGridAxes(),
+ grid.value().getShape()))) {
return failure();
}
auto scatterAxis = getScatterAxis().getSExtValue();
return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
- scatterAxis, getMeshAxes(),
- mesh.value().getShape());
+ scatterAxis, getGridAxes(),
+ grid.value().getShape());
}
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context);
}
void ScatterOp::getAsmResultNames(
@@ -1456,17 +1454,17 @@ void ScatterOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.send op
+// shard.send op
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
getDestination(), getDestinationDynamic(),
- getMeshAxes(), mesh.value().getShape()))) {
+ getGridAxes(), grid.value().getShape()))) {
return failure();
}
return success();
@@ -1474,7 +1472,7 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
+ patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context);
}
void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1482,20 +1480,20 @@ void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
}
//===----------------------------------------------------------------------===//
-// mesh.shift op
+// shard.shift op
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerifyAxes(*this, symbolTable);
+ if (failed(grid)) {
return failure();
}
- auto meshAxes = getMeshAxes();
+ auto gridAxes = getGridAxes();
auto shiftAxis = getShiftAxis().getZExtValue();
- if (!llvm::is_contained(meshAxes, shiftAxis)) {
+ if (!llvm::is_contained(gridAxes, shiftAxis)) {
return emitError() << "Invalid shift axis " << shiftAxis
- << ". It must be one of the grouping mesh axes.";
+ << ". It must be one of the grouping grid axes.";
}
return success();
@@ -1504,7 +1502,7 @@ LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// TODO: remove op when offset is 0 or if it is a rotate with and
- // offset % shift_axis_mesh_dim_size == 0.
+ // offset % shift_axis_grid_dim_size == 0.
}
void ShiftOp::getAsmResultNames(
@@ -1513,13 +1511,13 @@ void ShiftOp::getAsmResultNames(
}
//===----------------------------------------------------------------------===//
-// mesh.update_halo op
+// shard.update_halo op
//===----------------------------------------------------------------------===//
LogicalResult
UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
- if (failed(mesh)) {
+ auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable);
+ if (failed(grid)) {
return failure();
}
@@ -1531,12 +1529,12 @@ UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc"
#define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc"
#define GET_TYPEDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc"
-#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"
+#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
similarity index 76%
rename from mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
index afe76b539846a..01e8e56dd391d 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
DEPENDS
MLIRShardingInterfaceIncGen
@@ -10,7 +10,7 @@ add_mlir_library(MLIRShardingInterface
LINK_LIBS PUBLIC
MLIRDialectUtils
MLIRIR
- MLIRMeshDialect
+ MLIRShardDialect
MLIRTensorDialect
MLIRSupport
)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
similarity index 70%
rename from mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
rename to mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
index 6b3d49e08b549..3e8f10cc27fca 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
@@ -24,9 +24,9 @@
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"
//===----------------------------------------------------------------------===//
// common util functions
@@ -93,40 +93,39 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
}
template <typename T>
-SmallVector<MeshAxesAttr>
+SmallVector<GridAxesAttr>
fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) {
- SmallVector<MeshAxesAttr> res;
+ SmallVector<GridAxesAttr> res;
for (const auto &v : vec) {
- res.emplace_back(MeshAxesAttr::get(ctxt, v));
+ res.emplace_back(GridAxesAttr::get(ctxt, v));
}
return res;
}
//===----------------------------------------------------------------------===//
-// mesh::getMeshSharding
+// shard::getSharding
//===----------------------------------------------------------------------===//
-FailureOr<std::pair<bool, MeshSharding>>
-mesh::getMeshSharding(OpResult result) {
+FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpResult result) {
Value val = cast<Value>(result);
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
if (!shardOp)
return false;
return !shardOp.getAnnotateForUsers();
});
if (anyShardedForDef) {
- // expected to have exact one use if it has a use of `mesh.shard` without
+ // expected to have exact one use if it has a use of `shard.shard` without
// unit attr annotate_for_users
if (!val.hasOneUse())
return failure();
- auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
- return std::make_pair(false, MeshSharding(shardOp.getSharding()));
+ auto shardOp = llvm::cast<shard::ShardOp>(*val.getUsers().begin());
+ return std::make_pair(false, Sharding(shardOp.getSharding()));
}
bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
- auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
if (!shardOp)
return false;
return shardOp.getAnnotateForUsers();
@@ -138,24 +137,23 @@ mesh::getMeshSharding(OpResult result) {
if (shardOp)
shardOps.push_back(shardOp);
}
- MeshSharding shardForDef = shardOps[0].getSharding();
+ Sharding shardForDef = shardOps[0].getSharding();
for (size_t i = 1; i < shardOps.size(); ++i) {
- // TODO: Deduce a reasonable mesh sharding attr for def when they are
+ // TODO: Deduce a reasonable grid sharding attr for def when they are
// different
assert(shardForDef == shardOps[i].getSharding() &&
- "only support all shard ops have the same mesh sharding attr");
+ "only support all shard ops have the same grid sharding attr");
}
return std::make_pair(true, shardForDef);
}
return failure();
}
-FailureOr<std::pair<bool, MeshSharding>>
-mesh::getMeshSharding(OpOperand &opOperand) {
+FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpOperand &opOperand) {
Value val = opOperand.get();
if (ShardOp shardOp = val.getDefiningOp<ShardOp>())
return std::make_pair(shardOp.getAnnotateForUsers(),
- MeshSharding(shardOp.getSharding()));
+ Sharding(shardOp.getSharding()));
return failure();
}
@@ -164,7 +162,7 @@ mesh::getMeshSharding(OpOperand &opOperand) {
// ShardingInterface::verifyShardingInterfaceImpl
//===----------------------------------------------------------------------===//
-LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() {
Operation *op = getOperation();
// check operands and results type
@@ -201,7 +199,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
// ShardingInterface::printLoopTypesAndIndexingMaps
//===----------------------------------------------------------------------===//
-void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
os << "print loop types and indexing maps for: \n";
getOperation()->print(os);
os << "\n";
@@ -222,15 +220,15 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
namespace {
-// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
+// Update the given `shardingOption` according to `gridAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
- FlatSymbolRefAttr mesh,
- ArrayRef<MeshAxis> meshAxes,
+ FlatSymbolRefAttr grid,
+ ArrayRef<GridAxis> gridAxes,
unsigned loopIdx) {
- if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) ||
+ if ((shardingOption.grid && grid && shardingOption.grid != grid) ||
(!shardingOption.shardingArray[loopIdx].empty() &&
- shardingOption.shardingArray[loopIdx] != meshAxes)) {
+ shardingOption.shardingArray[loopIdx] != gridAxes)) {
LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator "
<< loopIdx << "\n");
return failure();
@@ -239,28 +237,28 @@ static LogicalResult fillShardingOption(Operation *op,
if (i == loopIdx)
continue;
- for (MeshAxis axis : meshAxes) {
+ for (GridAxis axis : gridAxes) {
if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
- LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
+ LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes "
<< axis << " duplicate");
return failure();
}
}
}
- if (mesh)
- shardingOption.mesh = mesh;
+ if (grid)
+ shardingOption.grid = grid;
if (shardingOption.shardingArray[loopIdx].empty())
- shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
- meshAxes.end());
+ shardingOption.shardingArray[loopIdx].append(gridAxes.begin(),
+ gridAxes.end());
return success();
}
} // namespace
FailureOr<ShardingOption>
-mesh::detail::defaultGetShardingOption(Operation *op,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings) {
+shard::detail::defaultGetShardingOption(Operation *op,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings) {
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
ShardingOption shardingOption;
@@ -276,25 +274,25 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// 1. Fill sharding option based on op results
for (auto shardingIt : llvm::enumerate(resultShardings)) {
- MeshSharding shardAttr = shardingIt.value();
+ Sharding shardAttr = shardingIt.value();
if (!shardAttr)
continue;
AffineMap map = maps[numOperands + shardingIt.index()];
anyShardingInResultsOrOperands = true;
if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) {
- shardingOption.mesh = shardAttr.getMeshAttr();
+ shardingOption.grid = shardAttr.getGridAttr();
} else {
// Handle the split axes: calculate the corresponding loop index for each
// split axes sub-array, and then store the sub-array to
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
if (failed(fillShardingOption(op, shardingOption,
- shardAttr.getMeshAttr(), axes, index)))
+ shardAttr.getGridAttr(), axes, index)))
return failure();
}
}
@@ -302,7 +300,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// 2. Fill sharding option based on operands
for (auto shardingIt : llvm::enumerate(operandShardings)) {
- MeshSharding shardAttr = shardingIt.value();
+ Sharding shardAttr = shardingIt.value();
if (!shardAttr)
continue;
@@ -316,7 +314,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
// then the operands with multiple loop indices.
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef();
FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
checkOperandAffineExpr(expr, numDims);
if (failed(loopIndices))
@@ -329,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op,
unsigned loopIdx = *loopIndices->begin();
visitedLoopIndices.insert(loopIdx);
if (failed(fillShardingOption(op, shardingOption,
- shardAttr.getMeshAttr(), axes, loopIdx)))
+ shardAttr.getGridAttr(), axes, loopIdx)))
return failure();
}
// If multiple loop indices correspond to a dimension of an operand, it is
@@ -361,11 +359,11 @@ mesh::detail::defaultGetShardingOption(Operation *op,
}
// Get the sharding attribute for the given result and sharding option.
-MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
- AffineMap map,
- ArrayRef<utils::IteratorType> loopTypes) {
+static Sharding getSharding(OpResult result,
+ const ShardingOption &shardingOption, AffineMap map,
+ ArrayRef<utils::IteratorType> loopTypes) {
auto resultType = cast<RankedTensorType>(result.getType());
- SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+ SmallVector<SmallVector<GridAxis>> splitAxes(resultType.getRank());
// process the split axes
for (auto it : llvm::enumerate(map.getResults())) {
@@ -379,25 +377,25 @@ MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption,
}
removeTrailingEmptySubArray(splitAxes);
- return MeshSharding::get(shardingOption.mesh,
- fromArrayOfVector(result.getContext(), splitAxes));
+ return Sharding::get(shardingOption.grid,
+ fromArrayOfVector(result.getContext(), splitAxes));
}
-static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
- const ShardingOption &shardingOption,
- AffineMap map) {
+static FailureOr<Sharding> getSharding(OpOperand &opOperand,
+ const ShardingOption &shardingOption,
+ AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
if (!operandType) {
if (operandValue.getType().isIntOrIndexOrFloat())
- return MeshSharding();
+ return Sharding();
return failure();
}
// 0d tensors cannot be sharded and must get replicated
if (operandType.getRank() == 0) {
- return MeshSharding(shardingOption.mesh);
+ return Sharding(shardingOption.grid);
}
- SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
+ SmallVector<SmallVector<GridAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
int64_t idx = it.index();
@@ -422,15 +420,14 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
}
removeTrailingEmptySubArray(splitAxes);
- return MeshSharding::get(
- shardingOption.mesh,
+ return Sharding::get(
+ shardingOption.grid,
fromArrayOfVector(opOperand.get().getContext(), splitAxes));
}
-FailureOr<std::vector<MeshSharding>>
-mesh::detail::defaultGetShardingAnnotations(
+FailureOr<std::vector<Sharding>> shard::detail::defaultGetShardingAnnotations(
Operation *op, const ShardingOption &shardingOption) {
- std::vector<MeshSharding> res;
+ std::vector<Sharding> res;
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
@@ -439,7 +436,7 @@ mesh::detail::defaultGetShardingAnnotations(
unsigned numOperands = op->getNumOperands();
for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<MeshSharding> shardingAttr = getSharding(
+ FailureOr<Sharding> shardingAttr = ::getSharding(
opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
if (failed(shardingAttr))
return failure();
@@ -447,9 +444,9 @@ mesh::detail::defaultGetShardingAnnotations(
}
for (OpResult result : op->getResults()) {
- res.push_back(getSharding(result, shardingOption,
- maps[numOperands + result.getResultNumber()],
- loopTypes));
+ res.push_back(::getSharding(result, shardingOption,
+ maps[numOperands + result.getResultNumber()],
+ loopTypes));
}
return res;
@@ -459,26 +456,25 @@ mesh::detail::defaultGetShardingAnnotations(
// detail::defaultAddShardingAnnotations
//===----------------------------------------------------------------------===//
-// To add a `mesh.shard` op for the given result, based on the details provided
+// To add a `shard.shard` op for the given result, based on the details provided
// in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpResult result,
const ShardingOption &shardingOption,
AffineMap map,
ArrayRef<utils::IteratorType> loopTypes) {
- MeshSharding sharding = getSharding(result, shardingOption, map, loopTypes);
+ Sharding sharding = getSharding(result, shardingOption, map, loopTypes);
maybeInsertTargetShardingAnnotation(sharding, result, b);
return success();
}
-// To add a `mesh.shard` op for the given operand, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
+// To add a `shard.shard` op for the given operand, based on the details
+// provided in `shardingOption`, `map`, and `loopTypes`.
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
- FailureOr<MeshSharding> sharding =
- getSharding(opOperand, shardingOption, map);
+ FailureOr<Sharding> sharding = getSharding(opOperand, shardingOption, map);
if (failed(sharding)) {
return failure();
}
@@ -488,9 +484,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
return success();
}
-LogicalResult mesh::detail::defaultAddShardingAnnotations(
+LogicalResult shard::detail::defaultAddShardingAnnotations(
Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
- assert(!shardingOption.empty && shardingOption.mesh);
+ assert(!shardingOption.empty && shardingOption.grid);
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
@@ -498,7 +494,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
- // 1. add mesh.shard ops for all op results
+ // 1. add shard.shard ops for all op results
for (OpResult result : op->getResults()) {
if (failed(addShardOp(b, result, shardingOption,
maps[numOperands + result.getResultNumber()],
@@ -506,7 +502,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
return failure();
}
- // 2. add mesh.shard ops for all operands
+ // 2. add shard.shard ops for all operands
for (OpOperand &opOperand : op->getOpOperands()) {
if (failed(addShardOp(b, opOperand, shardingOption,
maps[opOperand.getOperandNumber()])))
@@ -517,9 +513,8 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
}
#ifndef NDEBUG
-static bool
-isValueCompatibleWithFullReplicationSharding(Value value,
- MeshSharding sharding) {
+static bool isValueCompatibleWithFullReplicationSharding(Value value,
+ Sharding sharding) {
if (isa<RankedTensorType>(value.getType())) {
return isFullReplication(sharding);
}
@@ -527,60 +522,59 @@ isValueCompatibleWithFullReplicationSharding(Value value,
return !sharding;
}
-template <typename ValueRange, typename MeshShardingRage>
+template <typename ValueRange, typename ShardingRage>
static bool
areValuesCompatibleWithFullReplicationShardings(ValueRange &&values,
- MeshShardingRage &&shardings) {
+ ShardingRage &&shardings) {
if (std::size(values) != std::size(shardings)) {
return false;
}
- return llvm::all_of(
- llvm::zip_equal(std::forward<ValueRange>(values),
- std::forward<MeshShardingRage>(shardings)),
- [](auto valueAndSharding) {
- return isValueCompatibleWithFullReplicationSharding(
- std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
- });
+ return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values),
+ std::forward<ShardingRage>(shardings)),
+ [](auto valueAndSharding) {
+ return isValueCompatibleWithFullReplicationSharding(
+ std::get<0>(valueAndSharding),
+ std::get<1>(valueAndSharding));
+ });
}
#endif // NDEBUG
-void mesh::spmdizeFullyReplicatedOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable, OpBuilder &builder) {
- assert(spmdizedOperands.size() == operandShardings.size());
+void shard::partitionFullyReplicatedOperation(
+ Operation &op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap, SymbolTableCollection &symbolTable,
+ OpBuilder &builder) {
+ assert(partitiondOperands.size() == operandShardings.size());
assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(),
operandShardings));
assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(),
resultShardings));
// `clone` will populate the mapping of old to new results.
- builder.clone(op, spmdizationMap);
+ builder.clone(op, partitionMap);
}
-static void updateMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
- SmallVector<std::optional<SmallVector<MeshAxis>>>
- &meshAxesAssignmentForLoopIterators) {
+static void updateGridAxisAssignmentForLoopIterators(
+ ArrayRef<GridAxis> gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+ SmallVector<std::optional<SmallVector<GridAxis>>>
+ &gridAxesAssignmentForLoopIterators) {
AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
unsigned loopIteratorIdx = affineDimExpr.getPosition();
- if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
- assert(llvm::equal(meshAxesAssignmentForTensorAxis,
- *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+ if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+ assert(llvm::equal(gridAxesAssignmentForTensorAxis,
+ *gridAxesAssignmentForLoopIterators[loopIteratorIdx]));
} else {
- meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
- llvm::to_vector(meshAxesAssignmentForTensorAxis);
+ gridAxesAssignmentForLoopIterators[loopIteratorIdx] =
+ llvm::to_vector(gridAxesAssignmentForTensorAxis);
}
}
-ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
+ShardingArray shard::getGridAxisAssignmentForLoopIterators(
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
ArrayRef<utils::IteratorType> loopIteratorTypes,
ArrayRef<AffineMap> indexingMaps) {
- SmallVector<std::optional<SmallVector<MeshAxis>>>
- meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
- std::vector<MeshSharding> operatorAndResultShardings;
+ SmallVector<std::optional<SmallVector<GridAxis>>>
+ gridAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+ std::vector<Sharding> operatorAndResultShardings;
operatorAndResultShardings.reserve(operandShardings.size() +
resultShardings.size());
llvm::append_range(operatorAndResultShardings, operandShardings);
@@ -589,69 +583,69 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
if (!sharding) {
continue;
}
- for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+ for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] :
llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
- updateMeshAxisAssignmentForLoopIterators(
- meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
- meshAxisAssignmentForLoopIterators);
+ updateGridAxisAssignmentForLoopIterators(
+ gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+ gridAxisAssignmentForLoopIterators);
}
// Missing trailing split axes means replication on those tensor dimensions.
for (unsigned i = sharding.getSplitAxes().size();
i < affineMap.getNumResults(); ++i) {
- updateMeshAxisAssignmentForLoopIterators(
- {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+ updateGridAxisAssignmentForLoopIterators(
+ {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators);
}
}
ShardingArray res;
- llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
- [](std::optional<SmallVector<MeshAxis>> &axes) {
+ llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res),
+ [](std::optional<SmallVector<GridAxis>> &axes) {
if (!axes) {
- return SmallVector<MeshAxis>();
+ return SmallVector<GridAxis>();
};
return std::move(*axes);
});
return res;
}
-bool mesh::isAtLeastOneReductionIteratorSharded(
+bool shard::isAtLeastOneReductionIteratorSharded(
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
- for (auto [loopIteratorType, meshAxisAssignment] :
- llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
+ for (auto [loopIteratorType, gridAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
if (loopIteratorType == utils::IteratorType::reduction &&
- !meshAxisAssignment.empty()) {
+ !gridAxisAssignment.empty()) {
return true;
}
}
return false;
}
-SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+SmallVector<GridAxis> shard::getReductionGridAxes(
ArrayRef<utils::IteratorType> loopIteratorTypes,
- ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
- SmallVector<MeshAxis> meshAxes;
- for (auto [loopIteratorType, meshAxisAssignment] :
- llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) {
+ SmallVector<GridAxis> gridAxes;
+ for (auto [loopIteratorType, gridAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
if (loopIteratorType == utils::IteratorType::reduction) {
- llvm::append_range(meshAxes, meshAxisAssignment);
+ llvm::append_range(gridAxes, gridAxisAssignment);
}
}
- return meshAxes;
+ return gridAxes;
}
-void mesh::spmdizeTriviallyShardableOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable, OpBuilder &builder) {
+void shard::partitionTriviallyShardableOperation(
+ Operation &op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap, SymbolTableCollection &symbolTable,
+ OpBuilder &builder) {
// `clone` will populate the mapping of old to new results.
- Operation *newOp = builder.clone(op, spmdizationMap);
+ Operation *newOp = builder.clone(op, partitionMap);
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(shardType(
newResult.getType(),
- getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
+ getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding));
}
}
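
For illustration (not taken from the patch itself), a minimal sketch of the renamed
annotation ops that defaultAddShardingAnnotations inserts after this change; the
assembly below is approximated from the dialect's conventions, so the exact printed
syntax may differ:

  shard.grid @grid0(shape = 4)
  %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %1 = shard.shard %0 to %s annotate_for_users : tensor<8x16xf32>

The same IR was previously spelled with mesh.mesh, mesh.sharding and mesh.shard.
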
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
similarity index 73%
rename from mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
rename to mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
index 381bc9afede07..a884764e70e92 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt
@@ -1,14 +1,14 @@
-add_mlir_dialect_library(MLIRMeshTransforms
+add_mlir_dialect_library(MLIRShardTransforms
Simplifications.cpp
ShardingPropagation.cpp
- Spmdization.cpp
+ Partition.cpp
Transforms.cpp
ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard
DEPENDS
- MLIRMeshPassIncGen
+ MLIRShardPassIncGen
MLIRShardingInterface
LINK_LIBS PUBLIC
@@ -21,7 +21,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
- MLIRMeshDialect
+ MLIRShardDialect
MLIRPass
MLIRSupport
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
similarity index 66%
rename from mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index c137f525296e3..03ff1e5bf061e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -1,4 +1,4 @@
-//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//===- Partition.cpp --------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Shard/Transforms/Partition.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -36,7 +36,7 @@
#include <tuple>
#include <type_traits>
-namespace mlir::mesh {
+namespace mlir::shard {
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
@@ -46,52 +46,51 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
});
}
-static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
- MeshSharding sourceSharding,
- int64_t splitTensorAxis,
- MeshAxis splitMeshAxis) {
- SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
+ Sharding sourceSharding,
+ int64_t splitTensorAxis,
+ GridAxis splitGridAxis) {
+ SmallVector<GridAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
splitTensorAxis) {
- targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
}
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
- targetSplitAxes.push_back(splitMeshAxis);
+ targetSplitAxes.push_back(splitGridAxis);
targetShardingSplitAxes[splitTensorAxis] =
- MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(sourceSharding.getMeshAttr(),
- targetShardingSplitAxes);
+ GridAxesAttr::get(ctx, targetSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
-// Split a replicated tensor along a mesh axis.
+// Split a replicated tensor along a grid axis.
// E.g. [[0, 1]] -> [[0, 1, 2]].
-// Returns the spmdized target value with its sharding.
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
+// Returns the partitioned target value with its sharding.
+static std::tuple<TypedValue<ShapedType>, Sharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
- MeshSharding sourceSharding,
- TypedValue<ShapedType> sourceShard, MeshOp mesh,
- int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ Sharding sourceSharding,
+ TypedValue<ShapedType> sourceShard, GridOp grid,
+ int64_t splitTensorAxis, GridAxis splitGridAxis) {
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder
- .create<AllSliceOp>(sourceShard, mesh,
- ArrayRef<MeshAxis>(splitMeshAxis),
+ .create<AllSliceOp>(sourceShard, grid,
+ ArrayRef<GridAxis>(splitGridAxis),
splitTensorAxis)
.getResult());
- MeshSharding targetSharding = targetShardingInSplitLastAxis(
- builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+ Sharding targetSharding = targetShardingInSplitLastAxis(
+ builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
}
// Detect if the resharding is of type e.g.
// [[0, 1]] -> [[0, 1, 2]].
-// If detected, returns the corresponding tensor axis mesh axis pair.
+// If detected, returns the corresponding (tensor axis, grid axis) pair.
// Does not detect insertions like
// [[0, 1]] -> [[0, 2, 1]].
-static std::optional<std::tuple<int64_t, MeshAxis>>
-detectSplitLastAxisInResharding(MeshSharding sourceSharding,
- MeshSharding targetSharding) {
+static std::optional<std::tuple<int64_t, GridAxis>>
+detectSplitLastAxisInResharding(Sharding sourceSharding,
+ Sharding targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
++tensorAxis) {
if (sourceSharding.getSplitAxes().size() > tensorAxis) {
@@ -121,16 +120,15 @@ detectSplitLastAxisInResharding(MeshSharding sourceSharding,
return std::nullopt;
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
- auto [tensorAxis, meshAxis] = detectRes.value();
- return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
- tensorAxis, meshAxis);
+ auto [tensorAxis, gridAxis] = detectRes.value();
+ return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid,
+ tensorAxis, gridAxis);
}
return std::nullopt;
@@ -138,10 +136,10 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// Detect if the resharding is of type e.g.
// [[0, 1, 2]] -> [[0, 1]].
-// If detected, returns the corresponding tensor axis mesh axis pair.
-static std::optional<std::tuple<int64_t, MeshAxis>>
-detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
- MeshSharding targetSharding) {
+// If detected, returns the corresponding (tensor axis, grid axis) pair.
+static std::optional<std::tuple<int64_t, GridAxis>>
+detectUnsplitLastAxisInResharding(Sharding sourceSharding,
+ Sharding targetSharding) {
for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
++tensorAxis) {
if (targetSharding.getSplitAxes().size() > tensorAxis) {
@@ -168,10 +166,10 @@ detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
return std::nullopt;
}
-static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
- MeshSharding sourceSharding,
- int64_t splitTensorAxis) {
- SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+ Sharding sourceSharding,
+ int64_t splitTensorAxis) {
+ SmallVector<GridAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
splitTensorAxis);
@@ -180,9 +178,8 @@ static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
- MeshAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(sourceSharding.getMeshAttr(),
- targetShardingSplitAxes);
+ GridAxesAttr::get(ctx, targetSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
static ShapedType allGatherResultShapeInUnsplitLastAxis(
@@ -193,45 +190,42 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis(
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
- MeshSharding sourceSharding,
- ShapedType sourceUnshardedShape,
- TypedValue<ShapedType> sourceShard, MeshOp mesh,
- int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
+ ImplicitLocOpBuilder &builder, Sharding sourceSharding,
+ ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
+ GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- MeshSharding targetSharding =
+ Sharding targetSharding =
targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
- sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
+ sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
Value allGatherResult = AllGatherOp::create(
builder,
RankedTensorType::get(allGatherResultShape.getShape(),
allGatherResultShape.getElementType()),
- mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+ grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
APInt(64, splitTensorAxis));
ShapedType targetShape =
- shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ shardShapedType(sourceUnshardedShape, grid, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
tensor::CastOp::create(builder, targetShape, allGatherResult)
.getResult());
return {targetShard, targetSharding};
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
- auto [tensorAxis, meshAxis] = detectRes.value();
+ auto [tensorAxis, gridAxis] = detectRes.value();
return unsplitLastAxisInResharding(builder, sourceSharding,
- sourceUnshardedShape, sourceShard, mesh,
- tensorAxis, meshAxis);
+ sourceUnshardedShape, sourceShard, grid,
+ tensorAxis, gridAxis);
}
return std::nullopt;
@@ -241,10 +235,10 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// [[0, 1], [2]] -> [[0], [1, 2]].
// Only moving the last axis counts.
// If detected, returns the corresponding (source_tensor_axis,
-// target_tensor_axis, mesh_axis) tuple.
-static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
-detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
- MeshSharding targetSharding) {
+// target_tensor_axis, grid_axis) tuple.
+static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
+detectMoveLastSplitAxisInResharding(Sharding sourceSharding,
+ Sharding targetSharding) {
for (size_t sourceTensorAxis = 0;
sourceTensorAxis < sourceSharding.getSplitAxes().size();
++sourceTensorAxis) {
@@ -284,33 +278,32 @@ detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
return std::nullopt;
}
-static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
- MeshSharding sourceSharding,
- int64_t sourceTensorAxis,
- int64_t targetTensorAxis) {
- SmallVector<MeshAxesAttr> targetShardingSplitAxes =
+static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
+ Sharding sourceSharding,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
+ SmallVector<GridAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
targetTensorAxis) {
- targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
}
auto sourceSplitAxes =
llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
assert(!sourceSplitAxes.empty());
- auto meshAxis = sourceSplitAxes.back();
+ auto gridAxis = sourceSplitAxes.back();
sourceSplitAxes.pop_back();
targetShardingSplitAxes[sourceTensorAxis] =
- MeshAxesAttr::get(ctx, sourceSplitAxes);
+ GridAxesAttr::get(ctx, sourceSplitAxes);
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
- targetSplitAxes.push_back(meshAxis);
+ targetSplitAxes.push_back(gridAxis);
targetShardingSplitAxes[targetTensorAxis] =
- MeshAxesAttr::get(ctx, targetSplitAxes);
+ GridAxesAttr::get(ctx, targetSplitAxes);
- return MeshSharding::get(sourceSharding.getMeshAttr(),
- targetShardingSplitAxes);
+ return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
@@ -325,46 +318,46 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
-static std::tuple<TypedValue<ShapedType>, MeshSharding>
-moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
+static std::tuple<TypedValue<ShapedType>, Sharding>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard,
int64_t sourceTensorAxis,
- int64_t targetTensorAxis, MeshAxis meshAxis) {
+ int64_t targetTensorAxis, GridAxis gridAxis) {
MLIRContext *ctx = builder.getContext();
builder.setInsertionPointAfterValue(sourceShard);
- MeshSharding targetSharding = targetShardingInMoveLastAxis(
+ Sharding targetSharding = targetShardingInMoveLastAxis(
ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
- sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
+ sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
targetTensorAxis);
Value allToAllResult = AllToAllOp::create(
builder,
RankedTensorType::get(allToAllResultShape.getShape(),
allToAllResultShape.getElementType()),
- mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+ grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
- shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ shardShapedType(sourceUnshardedShape, grid, targetSharding);
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
return {targetShard, targetSharding};
}
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding,
+ Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
if (auto detectRes =
detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
- auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
+ auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
return moveLastSplitAxisInResharding(
- builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
- sourceTensorAxis, targetTensorAxis, meshAxis);
+ builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
+ sourceTensorAxis, targetTensorAxis, gridAxis);
}
return std::nullopt;
@@ -374,10 +367,9 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
// needed. Changed halo sizes require copying the "core" of the source tensor
// into the "core" of the destination tensor followed by an update halo
// operation.
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
-tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
+tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue<ShapedType> sourceShard) {
// Currently handles only cases where halo sizes differ but everything else
@@ -395,7 +387,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
ShapedType::isStaticShape(tgtHaloSizes) &&
sourceShard.getType().hasStaticShape()) &&
- "dynamic shapes/halos are not supported yet for mesh-spmdization");
+ "dynamic shapes/halos are not supported yet for shard-partition");
auto rank = sourceShard.getType().getRank();
auto splitAxes = sourceSharding.getSplitAxes();
SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
@@ -436,8 +428,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
sourceShard.getLoc(),
RankedTensorType::get(outShape,
sourceShard.getType().getElementType()),
- initOprnd, mesh.getSymName(),
- MeshAxesArrayAttr::get(builder.getContext(),
+ initOprnd, grid.getSymName(),
+ GridAxesArrayAttr::get(builder.getContext(),
sourceSharding.getSplitAxes()),
targetSharding.getDynamicHaloSizes(),
targetSharding.getStaticHaloSizes())
@@ -446,41 +438,41 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
targetSharding);
}
-// Handles only resharding on a 1D mesh.
+// Handles only resharding on a 1D grid.
// Currently the sharded tensor axes must be exactly divisible by the single
-// mesh axis size.
+// grid axis size.
static TypedValue<ShapedType>
-reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding, MeshSharding targetSharding,
+reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
assert(sourceShard.getType() ==
- shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+ shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
[[maybe_unused]] ShapedType targetShardType =
- shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+ shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
assert(sourceShard.getType().getRank() == targetShardType.getRank());
- assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+  assert(grid.getRank() == 1 && "Only 1D grids are currently supported.");
if (sourceSharding == targetSharding) {
return sourceShard;
}
TypedValue<ShapedType> targetShard;
- MeshSharding actualTargetSharding;
+ Sharding actualTargetSharding;
if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
targetSharding.getStaticShardedDimsOffsets().empty() &&
sourceSharding.getStaticHaloSizes().empty() &&
targetSharding.getStaticHaloSizes().empty()) {
if (auto tryRes = tryMoveLastSplitAxisInResharding(
- builder, mesh, sourceSharding, targetSharding,
+ builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
} else if (auto tryRes =
- trySplitLastAxisInResharding(builder, mesh, sourceSharding,
+ trySplitLastAxisInResharding(builder, grid, sourceSharding,
targetSharding, sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
} else if (auto tryRes = tryUnsplitLastAxisInResharding(
- builder, mesh, sourceSharding, targetSharding,
+ builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
std::tie(targetShard, actualTargetSharding) = tryRes.value();
}
@@ -491,9 +483,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
return targetShard;
}
-TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
- MeshSharding sourceSharding,
- MeshSharding targetSharding,
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid,
+ Sharding sourceSharding, Sharding targetSharding,
TypedValue<ShapedType> sourceUnshardedValue,
TypedValue<ShapedType> sourceShard) {
// If source and destination sharding are the same, no need to do anything.
@@ -503,28 +494,28 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
}
// Tries to handle the case where the resharding is needed because the halo
- // sizes are different. Supports arbitrary mesh dimensionality.
+ // sizes are different. Supports arbitrary grid dimensionality.
if (auto tryRes = tryUpdateHaloInResharding(
- builder, mesh, sourceSharding, targetSharding,
+ builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue.getType(), sourceShard)) {
return std::get<0>(tryRes.value()); // targetShard
}
- // Resort to handling only 1D meshes since the general case is complicated if
+ // Resort to handling only 1D grids since the general case is complicated if
// it needs to be communication efficient in terms of minimizing the data
  // transferred between devices.
- return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+ return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
sourceUnshardedValue, sourceShard);
}
-TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
+TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
assert(source.getResult() == target.getSrc());
auto sourceSharding = source.getSharding();
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
- return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
+ return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
cast<TypedValue<ShapedType>>(source.getSrc()),
sourceShardValue);
}
@@ -533,21 +524,21 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue,
SymbolTableCollection &symbolTableCollection) {
- MeshOp srcMesh = getMesh(source, symbolTableCollection);
- assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
- return reshard(builder, srcMesh, source, target, sourceShardValue);
+ GridOp srcGrid = getGrid(source, symbolTableCollection);
+ assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
+ return reshard(builder, srcGrid, source, target, sourceShardValue);
}
void reshardingRegisterDependentDialects(DialectRegistry ®istry) {
- registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
+ registry.insert<shard::ShardDialect, tensor::TensorDialect>();
}
-#define GEN_PASS_DEF_SPMDIZATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+#define GEN_PASS_DEF_PARTITION
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
using UnshardedToShardedValueMap = DenseMap<Value, Value>;
-// Get the types of block arguments for an spmdized block.
+// Get the types of block arguments for a partitioned block.
// Reads the sharding annotations of the arguments to deduce the sharded types.
// Types that are not ranked tensors are left unchanged.
SmallVector<Type>
@@ -566,35 +557,36 @@ shardedBlockArgumentTypes(Block &block,
Operation *useOp = *rankedTensorArg.getUsers().begin();
ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
assert(shardOp);
- MeshOp mesh = getMesh(shardOp, symbolTableCollection);
- return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
+ GridOp grid = getGrid(shardOp, symbolTableCollection);
+ return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
shardOp.getSharding()));
});
return res;
}
-static LogicalResult spmdizeOperation(
- Operation &op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection, OpBuilder &builder) {
+static LogicalResult
+partitionOperation(Operation &op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
if (!shardingInterface) {
// If there is no sharding interface we are conservative and assume that
      // the op should be fully replicated on all devices.
- spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTableCollection, builder);
+ partitionFullyReplicatedOperation(op, partitiondOperands, operandShardings,
+ resultShardings, partitionMap,
+ symbolTableCollection, builder);
} else {
- if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTableCollection, builder))) {
+ if (failed(shardingInterface.partition(partitiondOperands, operandShardings,
+ resultShardings, partitionMap,
+ symbolTableCollection, builder))) {
return failure();
}
}
- assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
- return spmdizationMap.contains(result);
+ assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
+ return partitionMap.contains(result);
}));
return success();
@@ -602,88 +594,88 @@ static LogicalResult spmdizeOperation(
// Retrieve the sharding annotations for the operands of the given operation.
// If the type is not a ranked tensor, it is not required to have an annotation.
-static std::vector<MeshSharding> getOperandShardings(Operation &op) {
- std::vector<MeshSharding> res;
+static std::vector<Sharding> getOperandShardings(Operation &op) {
+ std::vector<Sharding> res;
res.reserve(op.getNumOperands());
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(operand);
if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
- return MeshSharding();
+ return Sharding();
}
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- return MeshSharding(shardOp.getSharding());
+ return Sharding(shardOp.getSharding());
});
return res;
}
// Retrieve the sharding annotations for the results of the given operation.
// If the type is not a ranked tensor, it is not required to have an annotation.
-static std::vector<MeshSharding> getResultShardings(Operation &op) {
- std::vector<MeshSharding> res;
+static std::vector<Sharding> getResultShardings(Operation &op) {
+ std::vector<Sharding> res;
res.reserve(op.getNumResults());
llvm::transform(
op.getResults(), std::back_inserter(res), [&op](OpResult result) {
if (!result.hasOneUse() || result.use_empty()) {
- return MeshSharding();
+ return Sharding();
}
TypedValue<RankedTensorType> rankedTensor =
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
- return MeshSharding();
+ return Sharding();
}
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
if (shardOp) {
- return MeshSharding(shardOp.getSharding());
+ return Sharding(shardOp.getSharding());
}
if (rankedTensor.getType().getRank() == 0) {
// This is a 0d tensor result without explicit sharding.
- // Find mesh symbol from operands, if any.
- // Shardings without mesh are not always fully supported yet.
+ // Find grid symbol from operands, if any.
+            // Shardings without a grid are not always fully supported yet.
for (auto operand : op.getOperands()) {
if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
- return MeshSharding(sharding.getMeshAttr());
+ return Sharding(sharding.getGridAttr());
}
}
}
- return MeshSharding();
+ return Sharding();
});
return res;
}
static LogicalResult
-spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection,
- OpBuilder &builder) {
- Value targetSpmdValue;
+partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ Value targetPartitionValue;
  // Check if 2 shard ops are chained. If not, there is no need for resharding
  // as the source and target share the same sharding.
ShardOp srcShardOp =
dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
if (!srcShardOp) {
- targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
+ targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
} else {
// Insert resharding.
- TypedValue<ShapedType> srcSpmdValue =
- cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
- targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
- symbolTableCollection);
+ TypedValue<ShapedType> srcPartitionValue =
+ cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
+ targetPartitionValue = reshard(builder, srcShardOp, shardOp,
+ srcPartitionValue, symbolTableCollection);
}
- assert(!spmdizationMap.contains(shardOp.getResult()));
- spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ assert(!partitionMap.contains(shardOp.getResult()));
+ partitionMap.map(shardOp.getResult(), targetPartitionValue);
return success();
}
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection,
- OpBuilder &builder) {
+partitionOperation(Operation &op, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
if (isa<ShardingOp>(op)) {
return success();
}
@@ -693,30 +685,31 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
return op.emitError("expected a shard op as source of get_sharding");
}
auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
- spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
+ partitionMap.map(op.getResult(0), newSharding->getResult(0));
return success();
}
ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
if (shardOp) {
- return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
- builder);
+ return partitionOperation(shardOp, partitionMap, symbolTableCollection,
+ builder);
}
- SmallVector<Value> spmdizedOperands;
- llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
- [&spmdizationMap](Value operand) {
- assert(spmdizationMap.contains(operand));
- return spmdizationMap.lookup(operand);
+ SmallVector<Value> partitiondOperands;
+ llvm::transform(op.getOperands(), std::back_inserter(partitiondOperands),
+ [&partitionMap](Value operand) {
+ assert(partitionMap.contains(operand));
+ return partitionMap.lookup(operand);
});
- return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
- getResultShardings(op), spmdizationMap,
- symbolTableCollection, builder);
+ return partitionOperation(op, partitiondOperands, getOperandShardings(op),
+ getResultShardings(op), partitionMap,
+ symbolTableCollection, builder);
}
-static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection,
- OpBuilder &builder) {
+static LogicalResult
+partitionBlock(Block &block, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
SmallVector<Location> argLocations;
llvm::transform(block.getArguments(), std::back_inserter(argLocations),
@@ -724,16 +717,16 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
Block *newBlock = builder.createBlock(
block.getParent(), {},
shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
- for (auto [unshardedBlockArg, spmdizedBlockArg] :
+ for (auto [unshardedBlockArg, partitiondBlockArg] :
llvm::zip(block.getArguments(), newBlock->getArguments())) {
- spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
+ partitionMap.map(unshardedBlockArg, partitiondBlockArg);
}
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointToEnd(newBlock);
for (Operation &op : block.getOperations()) {
- if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
- builder))) {
+ if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
+ builder))) {
return failure();
}
}
@@ -742,8 +735,8 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
}
static LogicalResult
-spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTableCollection) {
+partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
+ SymbolTableCollection &symbolTableCollection) {
OpBuilder builder(op.getFunctionBody());
// Snapshot the original blocks to not mess up the iteration when adding new
@@ -757,8 +750,8 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
}
for (Block *block : originalBlocks) {
- if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
- builder))) {
+ if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
+ builder))) {
return failure();
}
}
@@ -791,22 +784,22 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
namespace {
-struct Spmdization : public impl::SpmdizationBase<Spmdization> {
+struct Partition : public impl::PartitionBase<Partition> {
void runOnOperation() override {
- IRMapping spmdizationMap;
+ IRMapping partitionMap;
SymbolTableCollection symbolTableCollection;
- if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
- symbolTableCollection))) {
+ if (failed(partitionFuncOp(getOperation(), partitionMap,
+ symbolTableCollection))) {
return signalPassFailure();
}
}
void getDependentDialects(DialectRegistry ®istry) const override {
reshardingRegisterDependentDialects(registry);
- registry.insert<mesh::MeshDialect>();
+ registry.insert<shard::ShardDialect>();
}
};
} // namespace
-} // namespace mlir::mesh
+} // namespace mlir::shard
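
To illustrate the renamed resharding helpers above (split/unsplit/move of the last
grid axis), an "unsplit last axis" reshard such as [[0]] -> [[]] is realized with an
all-gather over the released grid axis, roughly as sketched below; the assembly is
approximated and the exact printed syntax may differ:

  %g = shard.all_gather %shard on @grid0 grid_axes = [0] gather_axis = 0
       : tensor<2x8xf32> -> tensor<4x8xf32>

A split ([[]] -> [[0]]) analogously uses shard.all_slice, and moving the last split
axis between tensor dimensions uses shard.all_to_all.
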
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
similarity index 85%
rename from mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
rename to mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index 09c754da7a6b7..17220c4256bf3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/STLExtras.h"
@@ -21,17 +21,17 @@
#include <vector>
namespace mlir {
-namespace mesh {
+namespace shard {
#define GEN_PASS_DEF_SHARDINGPROPAGATION
-#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
-} // namespace mesh
+#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
+} // namespace shard
} // namespace mlir
#define DEBUG_TYPE "sharding-propagation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
enum class ReshardingRquirementKind {
NO_RESHARDING = 0,
@@ -68,7 +68,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
[[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
const ShardingOption &v) {
- return stream << "{empty = " << v.empty << ", mesh" << v.mesh
+ return stream << "{empty = " << v.empty << ", grid" << v.grid
<< ", shardingArray = " << v.shardingArray << "}";
}
@@ -105,15 +105,15 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) {
// specific shardings. For example, mustShardings = [shard0, None] and
// optionalShardings = [None, shard1], the result will be [[shard0, shard1],
// [shard0, None]]
-static SmallVector<std::vector<MeshSharding>>
-getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
- ArrayRef<MeshSharding> optionalShardings) {
- SmallVector<std::vector<MeshSharding>> allShardingAttrs;
- std::vector<MeshSharding> curShardingAttrs;
+static SmallVector<std::vector<Sharding>>
+getOrderedPossibleShardingAttrs(ArrayRef<Sharding> mustShardings,
+ ArrayRef<Sharding> optionalShardings) {
+ SmallVector<std::vector<Sharding>> allShardingAttrs;
+ std::vector<Sharding> curShardingAttrs;
std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) {
if (i == mustShardings.size()) {
- allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs));
+ allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs));
return;
}
@@ -147,14 +147,14 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings,
// 1. No resharding is required (all existing annotations are compatible).
// 2. No resharding for operands/results that have annotation specifically
// targeting this operation. This means
-// * operands that are the result of `mesh.shard` ops marked with
+// * operands that are the result of `shard.shard` ops marked with
// `annotate_for_users`.
-// * results that are annotated with `mesh.shard` ops without
+// * results that are annotated with `shard.shard` ops without
// `annotate_for_users`.
// 3. All other cases. Resharding is required for operands/results with
// annotation targeting explicitly this operation.
ReshardingRquirementKind getReshardingRquirementKind(
- Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) {
+ Operation *op, const std::vector<Sharding> &operandAndResultShardings) {
ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
size_t operandsCount = op->getOperands().size();
@@ -213,14 +213,13 @@ ReshardingRquirementKind getReshardingRquirementKind(
// 3. Resharding of existing explicit sharding annotations for this op.
static FailureOr<ShardingOption> selectShardingOption(
ShardingInterface shardingOp,
- ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs,
- ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) {
+ ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs,
+ ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) {
SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>>
shardingOptionsAndReshardingRequirements;
- for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) {
- for (ArrayRef<MeshSharding> operandShardings :
- possibleOperandShardingAttrs) {
+ for (ArrayRef<Sharding> resultShardings : possibleResultShardingAttrs) {
+ for (ArrayRef<Sharding> operandShardings : possibleOperandShardingAttrs) {
FailureOr<ShardingOption> shardingOption =
shardingOp.getShardingOption(operandShardings, resultShardings);
if (failed(shardingOption) || shardingOption->empty) {
@@ -231,7 +230,7 @@ static FailureOr<ShardingOption> selectShardingOption(
// They may be missing some annotations.
// Whatever is returned by getShardingAnnotations is exactly what the op
// needs.
- FailureOr<std::vector<MeshSharding>> operandAndResultShardings =
+ FailureOr<std::vector<Sharding>> operandAndResultShardings =
shardingOp.getShardingAnnotations(*shardingOption);
if (failed(operandAndResultShardings)) {
return failure();
@@ -276,13 +275,13 @@ static FailureOr<ShardingOption> selectShardingOption(
// For each operation that implements the ShardingInterface, infer the sharding
// option of the operation from its operands and/or results using the
// `getShardingOption` method. If the inferred sharding option is not empty, add
-// a `mesh.shard` operation for all remaining operands and results that do not
+// a `shard.shard` operation for all remaining operands and results that do not
// have sharding annotations.
static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
if (op->hasTrait<OpTrait::IsTerminator>() ||
(op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) ||
- llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op))
+ llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op))
return success();
if (!shardingOp) {
@@ -290,14 +289,13 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
return failure();
}
- // collect MeshSharding from results
- std::vector<MeshSharding> allowConflictsResultShardings;
+ // collect Sharding from results
+ std::vector<Sharding> allowConflictsResultShardings;
allowConflictsResultShardings.resize(op->getNumResults());
- std::vector<MeshSharding> resultMustShardings;
+ std::vector<Sharding> resultMustShardings;
resultMustShardings.resize(op->getNumResults());
for (OpResult result : op->getResults()) {
- FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
- getMeshSharding(result);
+ FailureOr<std::pair<bool, Sharding>> maybeShardAttr = getSharding(result);
if (failed(maybeShardAttr))
continue;
if (!maybeShardAttr->first)
@@ -307,14 +305,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
maybeShardAttr->second;
}
- // collect MeshSharding from operands
- std::vector<MeshSharding> allowConflictsOperandShardings;
+ // collect Sharding from operands
+ std::vector<Sharding> allowConflictsOperandShardings;
allowConflictsOperandShardings.resize(op->getNumOperands());
- std::vector<MeshSharding> operandMustShardings;
+ std::vector<Sharding> operandMustShardings;
operandMustShardings.resize(op->getNumOperands());
for (OpOperand &opOperand : op->getOpOperands()) {
- FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr =
- getMeshSharding(opOperand);
+ FailureOr<std::pair<bool, Sharding>> maybeShardAttr =
+ getSharding(opOperand);
if (failed(maybeShardAttr))
continue;
@@ -327,10 +325,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
}
// try to get the sharding option
- SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs =
+ SmallVector<std::vector<Sharding>> possibleOperandShardingAttrs =
getOrderedPossibleShardingAttrs(operandMustShardings,
allowConflictsOperandShardings);
- SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs =
+ SmallVector<std::vector<Sharding>> possibleResultShardingAttrs =
getOrderedPossibleShardingAttrs(resultMustShardings,
allowConflictsResultShardings);
FailureOr<ShardingOption> shardingOption = selectShardingOption(
@@ -358,7 +356,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
// ShardingPropagation
//===----------------------------------------------------------------------===//
struct ShardingPropagation
- : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+ : public shard::impl::ShardingPropagationBase<ShardingPropagation> {
using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
@@ -376,8 +374,7 @@ struct ShardingPropagation
LLVM_DEBUG(
DBGS() << "print all the ops' iterator types and indexing maps in the "
"block.\n";
- for (Operation &op
- : block.getOperations()) {
+ for (Operation &op : block.getOperations()) {
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
similarity index 66%
rename from mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
index db5fd6e494da1..7b82dc5c613d3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp
@@ -1,4 +1,4 @@
-//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
@@ -19,7 +19,7 @@
#include <numeric>
namespace mlir {
-namespace mesh {
+namespace shard {
void populateSimplificationPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
@@ -53,53 +53,53 @@ namespace {
// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
-// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder
- : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+// pass changing the referenced GridOp ops.
+struct GridShapeFolder
+ : OpRewritePatternWithSymbolTableCollection<GridShapeOp> {
using OpRewritePatternWithSymbolTableCollection::
OpRewritePatternWithSymbolTableCollection;
- LogicalResult matchAndRewrite(MeshShapeOp op,
+ LogicalResult matchAndRewrite(GridShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
- MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
- op.getOperation(), op.getMeshAttr());
- if (!mesh) {
+ GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>(
+ op.getOperation(), op.getGridAttr());
+ if (!grid) {
return failure();
}
- ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
- SmallVector<MeshAxis> opAxesIota;
- if (opMeshAxes.empty()) {
- opAxesIota.resize(mesh.getRank());
+ ArrayRef<GridAxis> opGridAxes = op.getAxes();
+ SmallVector<GridAxis> opAxesIota;
+ if (opGridAxes.empty()) {
+ opAxesIota.resize(grid.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
- opMeshAxes = opAxesIota;
+ opGridAxes = opAxesIota;
}
- if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
- return ShapedType::isDynamic(mesh.getShape()[axis]);
+ if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) {
+ return ShapedType::isDynamic(grid.getShape()[axis]);
})) {
- // All mesh dimensions are dynamic. Nothing to fold.
+ // All grid dimensions are dynamic. Nothing to fold.
return failure();
}
SmallVector<Value> newResults(op->getResults().size());
- SmallVector<MeshAxis> newShapeOpMeshAxes;
+ SmallVector<GridAxis> newShapeOpGridAxes;
SmallVector<size_t> newToOldResultsIndexMap;
- for (size_t i = 0; i < opMeshAxes.size(); ++i) {
- auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
- if (ShapedType::isDynamic(meshAxisSize)) {
+ for (size_t i = 0; i < opGridAxes.size(); ++i) {
+ auto gridAxisSize = grid.getShape()[opGridAxes[i]];
+ if (ShapedType::isDynamic(gridAxisSize)) {
newToOldResultsIndexMap.push_back(i);
- newShapeOpMeshAxes.push_back(opMeshAxes[i]);
+ newShapeOpGridAxes.push_back(opGridAxes[i]);
} else {
- // Fold static mesh axes.
+ // Fold static grid axes.
newResults[i] = arith::ConstantOp::create(
- builder, builder.getIndexAttr(meshAxisSize));
+ builder, builder.getIndexAttr(gridAxisSize));
}
}
- // Leave only the dynamic mesh axes to be queried.
- if (!newShapeOpMeshAxes.empty()) {
- MeshShapeOp newShapeOp =
- MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes);
+ // Leave only the dynamic grid axes to be queried.
+ if (!newShapeOpGridAxes.empty()) {
+ GridShapeOp newShapeOp =
+ GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes);
for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
}
@@ -114,8 +114,8 @@ struct MeshShapeFolder
void populateFoldingPatterns(RewritePatternSet &patterns,
SymbolTableCollection &symbolTableCollection) {
- patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
+ patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext());
}
-} // namespace mesh
+} // namespace shard
} // namespace mlir
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
similarity index 78%
rename from mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
rename to mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index 6ae95ae1f8a49..f1f8a3df92b6d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
@@ -14,8 +14,8 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -29,12 +29,12 @@
#include <iterator>
#include <numeric>
-namespace mlir::mesh {
+namespace mlir::shard {
namespace {
-/// Lower `mesh.process_multi_index` into expression using
-/// `mesh.process_linear_index` and `mesh.mesh_shape`.
+/// Lower `shard.process_multi_index` into expression using
+/// `shard.process_linear_index` and `shard.grid_shape`.
struct ProcessMultiIndexOpLowering
: OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
using OpRewritePatternWithSymbolTableCollection::
@@ -42,30 +42,30 @@ struct ProcessMultiIndexOpLowering
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
- MeshOp mesh = getMesh(op, symbolTableCollection);
- if (!mesh) {
+ GridOp grid = getGrid(op, symbolTableCollection);
+ if (!grid) {
return failure();
}
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
builder.setInsertionPointAfter(op.getOperation());
- Value linearIndex = ProcessLinearIndexOp::create(builder, mesh);
- ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults();
+ Value linearIndex = ProcessLinearIndexOp::create(builder, grid);
+ ValueRange gridShape = GridShapeOp::create(builder, grid).getResults();
SmallVector<Value> completeMultiIndex =
affine::AffineDelinearizeIndexOp::create(builder, linearIndex,
- meshShape)
+ gridShape)
.getMultiIndex();
SmallVector<Value> multiIndex;
- ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
- SmallVector<MeshAxis> opAxesIota;
- if (opMeshAxes.empty()) {
- opAxesIota.resize(mesh.getRank());
+ ArrayRef<GridAxis> opGridAxes = op.getAxes();
+ SmallVector<GridAxis> opAxesIota;
+ if (opGridAxes.empty()) {
+ opAxesIota.resize(grid.getRank());
std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
- opMeshAxes = opAxesIota;
+ opGridAxes = opAxesIota;
}
- llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
- [&completeMultiIndex](MeshAxis meshAxis) {
- return completeMultiIndex[meshAxis];
+ llvm::transform(opGridAxes, std::back_inserter(multiIndex),
+ [&completeMultiIndex](GridAxis gridAxis) {
+ return completeMultiIndex[gridAxis];
});
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
@@ -87,15 +87,15 @@ struct AllSliceOpLowering
// axis.
// The slice axis is split into equisized parts with count
// the number of processes in the collective process group induced by
- // the mesh axes.
+ // the grid axes.
// The part for each process is determined by the corresponding
// linear-index in the process group.
//
// There are no collectives that require communication.
// Each process operates on its local tensor.
- MeshOp mesh = getMesh(op, symbolTableCollection);
- if (!mesh) {
+ GridOp grid = getGrid(op, symbolTableCollection);
+ if (!grid) {
return failure();
}
@@ -105,15 +105,15 @@ struct AllSliceOpLowering
Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0));
Operation::result_range processInGroupMultiIndex =
- ProcessMultiIndexOp::create(builder, mesh.getSymName(),
- op.getMeshAxes())
+ ProcessMultiIndexOp::create(builder, grid.getSymName(),
+ op.getGridAxes())
.getResults();
Operation::result_range processGroupShape =
- MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes())
+ GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes())
.getResult();
Value processGroupSize =
- createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+ createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder);
int64_t sliceAxis = op.getSliceAxis().getSExtValue();
Value operandSliceAxisSize =
@@ -126,7 +126,7 @@ struct AllSliceOpLowering
cf::AssertOp::create(builder, isTargetShapeExactlyDivisible,
"Slicing a tensor with axis size that is "
"not exactly divisible by the "
- "mesh process group size is not supported.");
+ "grid process group size is not supported.");
Value resultSliceAxisSize =
arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize);
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -173,7 +173,7 @@ void populateProcessMultiIndexOpLoweringPatterns(
}
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
- registry.insert<affine::AffineDialect, mesh::MeshDialect>();
+ registry.insert<affine::AffineDialect, shard::ShardDialect>();
}
void populateAllSliceOpLoweringPatterns(
@@ -184,7 +184,7 @@ void populateAllSliceOpLoweringPatterns(
void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, arith::ArithDialect,
- cf::ControlFlowDialect, mesh::MeshDialect,
+ cf::ControlFlowDialect, shard::ShardDialect,
tensor::TensorDialect>();
}
@@ -200,21 +200,21 @@ void registerAllOpLoweringDialects(DialectRegistry &registry) {
}
TypedValue<IndexType>
-createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
ImplicitLocOpBuilder &builder) {
- Operation::result_range meshShape =
- mesh::MeshShapeOp::create(builder, mesh, axes).getResults();
+ Operation::result_range gridShape =
+ GridShapeOp::create(builder, grid, axes).getResults();
return cast<TypedValue<IndexType>>(arith::createProduct(
- builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
+ builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape),
builder.getIndexType()));
}
TypedValue<IndexType>
-createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
- ArrayRef<MeshAxis> meshAxes,
+createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
+ ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder) {
Operation::result_range processGroupShape =
- MeshShapeOp::create(builder, mesh, meshAxes).getResult();
+ GridShapeOp::create(builder, grid, gridAxes).getResult();
OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
@@ -226,11 +226,11 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
return cast<TypedValue<IndexType>>(res);
}
-TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
- ArrayRef<MeshAxis> meshAxes,
+TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
+ ArrayRef<GridAxis> gridAxes,
ImplicitLocOpBuilder &builder) {
return createProcessLinearIndex(
- mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(),
- meshAxes, builder);
+ grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+ gridAxes, builder);
}
-} // namespace mlir::mesh
+} // namespace mlir::shard
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
similarity index 82%
rename from mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
rename to mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
index 3e3f584caca24..60c9828ba736d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
+++ b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h
@@ -6,14 +6,14 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
-#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir {
-namespace mesh {
+namespace shard {
template <typename Op>
struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
@@ -29,7 +29,7 @@ struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
SymbolTableCollection &symbolTableCollection;
};
-} // namespace mesh
+} // namespace shard
} // namespace mlir
-#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H
diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
index 0421a6c0ff806..0784615b8edb8 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
-#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
+#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
using namespace mlir;
diff --git a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
index dba59333666f6..8f0b7da1fd7b5 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
@@ -1,10 +1,10 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
- MeshShardingExtensions.cpp
+ ShardingExtensions.cpp
)
-add_mlir_extension_library(MLIRTensorMeshShardingExtensions
- MeshShardingExtensions.cpp
+add_mlir_extension_library(MLIRTensorShardingExtensions
+ ShardingExtensions.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
@@ -22,5 +22,5 @@ add_mlir_extension_library(MLIRTensorAllExtensions
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
LINK_LIBS PUBLIC
- MLIRTensorMeshShardingExtensions
+ MLIRTensorShardingExtensions
)
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
similarity index 74%
rename from mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
rename to mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
index 26406ceef082c..75527d73deaaf 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
@@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
using namespace mlir::tensor;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
@@ -40,20 +40,20 @@ struct CreatorOpShardingInterface
{AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
assert(resultShardings.size() == 1);
auto resType = cast<RankedTensorType>(op->getResult(0).getType());
- mlir::mesh::MeshOp mesh;
+ mlir::shard::GridOp grid;
ShapedType shardType;
if (resType.getRank() > 0) {
- mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+ grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable);
shardType =
- cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0]));
+ cast<ShapedType>(shard::shardType(resType, grid, resultShardings[0]));
} else {
shardType = resType;
}
@@ -67,7 +67,7 @@ struct CreatorOpShardingInterface
auto oldType = cast<ShapedType>(resType);
assert(oldType.getRank() == shardType.getRank());
int currOldOprndNum = -1;
- mesh::ShardShapeOp shapeForDevice;
+ shard::ShardShapeOp shapeForDevice;
ValueRange device;
Operation *newSharding = nullptr;
for (auto i = 0; i < oldType.getRank(); ++i) {
@@ -76,23 +76,23 @@ struct CreatorOpShardingInterface
newSharding =
builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
device =
- builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
+ builder.create<shard::ProcessMultiIndexOp>(op->getLoc(), grid)
.getResults();
- shapeForDevice = builder.create<mesh::ShardShapeOp>(
- op->getLoc(), oldType.getShape(), spmdizedOperands,
+ shapeForDevice = builder.create<shard::ShardShapeOp>(
+ op->getLoc(), oldType.getShape(), partitiondOperands,
newSharding->getResult(0), device);
}
newOperands.emplace_back(shapeForDevice.getResult()[i]);
} else if (oldType.isDynamicDim(i)) {
assert(shardType.isDynamicDim(i));
- newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
+ newOperands.emplace_back(partitiondOperands[++currOldOprndNum]);
}
}
newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
- spmdizationMap.map(op->getResult(0), newOp->getResult(0));
+ partitionMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
- newOp = builder.clone(*op, spmdizationMap);
+ newOp = builder.clone(*op, partitionMap);
}
newOp->getResult(0).setType(shardType);
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index b1fac8c85a204..c6a438d348946 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
LINK_LIBS PUBLIC
MLIRIR
- MLIRMeshDialect
+ MLIRShardDialect
MLIRShardingInterface
MLIRSupport
MLIRTosaDialect
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index d3a5f44798106..45994a7ec679f 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
@@ -19,7 +19,7 @@
using namespace mlir;
using namespace mlir::tosa;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
@@ -87,15 +87,15 @@ struct NegateOpSharding
return maps;
}
- LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
- ArrayRef<MeshSharding> operandShardings,
- ArrayRef<MeshSharding> resultShardings,
- IRMapping &spmdizationMap,
- SymbolTableCollection &symbolTable,
- OpBuilder &builder) const {
- spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
- resultShardings, spmdizationMap,
- symbolTable, builder);
+ LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands,
+ ArrayRef<Sharding> operandShardings,
+ ArrayRef<Sharding> resultShardings,
+ IRMapping &partitionMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ partitionTriviallyShardableOperation(*op, partitiondOperands,
+ operandShardings, resultShardings,
+ partitionMap, symbolTable, builder);
return success();
}
};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 648e508a9788f..ecd93ff4c6e7b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -13,8 +13,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
@@ -166,7 +166,7 @@ void TosaDialect::initialize() {
>();
addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
declarePromisedInterfaces<
- mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
+ shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
similarity index 90%
rename from mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
rename to mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index d54d0034da5be..5e20b5a59d927 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -1,14 +1,14 @@
-// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-shard-to-mpi -canonicalize -split-input-file | FileCheck %s
// -----
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
+// CHECK: shard.grid @grid0
+shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK: mpi.comm_rank
// CHECK-DAG: %[[v4:.*]] = arith.remsi
// CHECK-DAG: %[[v0:.*]] = arith.remsi
// CHECK-DAG: %[[v1:.*]] = arith.remsi
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
@@ -17,7 +17,7 @@ func.func @process_multi_index() -> (index, index, index) {
func.func @process_linear_index() -> index {
// CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank
// CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
- %0 = mesh.process_linear_index on @mesh0 : index
+ %0 = shard.process_linear_index on @grid0 : index
// CHECK: return %[[cast]] : index
return %0 : index
}
@@ -29,7 +29,7 @@ func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
%c4 = arith.constant 4 : index
// CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
// CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
- %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index
+ %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [0] : index, index
// CHECK: return [[down]], [[up]] : index, index
return %idx#0, %idx#1 : index, index
}
@@ -41,7 +41,7 @@ func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
%c4 = arith.constant 4 : index
// CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
// CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
- %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index
+ %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [1] : index, index
// CHECK: return [[down]], [[up]] : index, index
return %idx#0, %idx#1 : index, index
}
@@ -53,20 +53,20 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
%c4 = arith.constant 4 : index
// CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
// CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
- %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index
+ %idx:2 = shard.neighbors_linear_indices on @grid0[%c1, %c0, %c4] split_axes = [2] : index, index
// CHECK: return [[down]], [[up]] : index, index
return %idx#0, %idx#1 : index, index
}
// -----
-// CHECK: mesh.mesh @mesh0
+// CHECK: shard.grid @grid0
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
- mesh.mesh @mesh0(shape = 3x4x5)
+ shard.grid @grid0(shape = 3x4x5)
func.func @process_multi_index() -> (index, index, index) {
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
@@ -74,7 +74,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
// CHECK: %[[c24:.*]] = arith.constant 24 : index
- %0 = mesh.process_linear_index on @mesh0 : index
+ %0 = shard.process_linear_index on @grid0 : index
// CHECK: return %[[c24]] : index
return %0 : index
}
@@ -82,7 +82,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
- mesh.mesh @mesh0(shape = 3x4x5)
+ shard.grid @grid0(shape = 3x4x5)
// CHECK-LABEL: func.func @allreduce_tensor(
func.func @allreduce_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
@@ -97,7 +97,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
// CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
// CHECK: return [[v2]] : tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}
@@ -114,7 +114,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
// CHECK: return [[valloc]] : memref<3x4xf32>
return %0 : memref<3x4xf32>
}
@@ -131,14 +131,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
// CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
// CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
// CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
// CHECK: return [[valloc]] : memref<3x4xf64>
return %0 : memref<3x4xf64>
}
}
// -----
-mesh.mesh @mesh0(shape = 3x4x5)
+shard.grid @grid0(shape = 3x4x5)
// CHECK-LABEL: func @update_halo_1d_first
func.func @update_halo_1d_first(
// CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
@@ -155,14 +155,14 @@ func.func @update_halo_1d_first(
// CHECK: mpi.recv(
// CHECK-SAME: : memref<3x120x120xi8>, i32, i32
// CHECK: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
// CHECK: return [[res:%.*]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
- mesh.mesh @mesh0(shape = 4)
+ shard.grid @grid0(shape = 4)
// CHECK-LABEL: func @update_halo_1d_with_zero
func.func @update_halo_1d_with_zero (
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
@@ -179,7 +179,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
// CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1]>>
// CHECK: memref.dealloc [[valloc]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
@@ -187,7 +187,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
// -----
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
- mesh.mesh @mesh0(shape = 3x4x5)
+ shard.grid @grid0(shape = 3x4x5)
// CHECK-LABEL: func @update_halo_3d
func.func @update_halo_3d(
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
@@ -236,7 +236,7 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
// CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v2]]) : memref<2x120x120xi8>, i32, i32
// CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
}
@@ -291,18 +291,18 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]], [[v3]]) : memref<2x120x120xi8>, i32, i32
// CHECK: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
// CHECK: [[v4:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+ %res = shard.update_halo %arg0 on @grid0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
// CHECK: return [[v4]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
}
}
// -----
-mesh.mesh @mesh0(shape = 2x2x4)
+shard.grid @grid0(shape = 2x2x4)
// CHECK-LABEL: func.func @return_sharding(
// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !shard.sharding) {
+ %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
// CHECK: [[vcm1_i16:%.*]] = arith.constant -1 : i16
@@ -316,13 +316,13 @@ func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sh
// CHECK: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
- return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+ return %arg0, %sharding : tensor<2x4xf32>, !shard.sharding
}
// CHECK-LABEL: func.func @return_sharding_halos(
// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !shard.sharding) {
+ %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
// CHECK: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
@@ -336,13 +336,13 @@ func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !m
// CHECK: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
// CHECK: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
- return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
+ return %arg0, %sharding : tensor<6x8xf32>, !shard.sharding
}
// CHECK-LABEL: func.func @return_sharding_offs(
// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
-func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !shard.sharding) {
+ %sharding = shard.sharding @grid0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !shard.sharding
// CHECK: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
// CHECK: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
// CHECK: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
@@ -362,5 +362,5 @@ func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !me
// CHECK: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
// CHECK: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
// CHECK: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
- return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
+ return %arg0, %sharding : tensor<?x?xf32>, !shard.sharding
}
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
similarity index 62%
rename from mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
rename to mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
index 156bbfb54845b..9729d2bfb384e 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shardshape-to-mpi.mlir
@@ -1,21 +1,21 @@
-// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --convert-shard-to-mpi -canonicalize | FileCheck %s
module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
- // CHECK: mesh.mesh @mesh0
- mesh.mesh @mesh0(shape = 3x4x5)
+ // CHECK: shard.grid @grid0
+ shard.grid @grid0(shape = 3x4x5)
- // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0
+ // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @grid0
// all shards are equal
// CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
func.func @shard_shape_equal() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
%c12 = arith.constant 12 : index
// CHECK: [[vc3:%.*]] = arith.constant 3 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -23,13 +23,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// last shard in last dim gets an extra element
// CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
func.func @shard_shape_odd_1() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
%c12 = arith.constant 12 : index
// CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -37,11 +37,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// In the second dimension the shard sizes are now [3 4 4 4]
// CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
func.func @shard_shape_odd_2() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
// CHECK: [[vc3:%.*]] = arith.constant 3 : index
- %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -49,11 +49,11 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// In the first dimension the shard sizes are now [3 4 4]
// CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
func.func @shard_shape_odd_3() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]] : !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
// CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
// CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
@@ -61,14 +61,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// extract from sharded_dims_offsets
// CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
- sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %sharding = shard.sharding @grid0 split_axes = [[0], [1], [2]]
+ sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !shard.sharding
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
%c9 = arith.constant 9 : index
%c12 = arith.constant 12 : index
// CHECK: [[vc3:%.*]] = arith.constant 3 : index
// CHECK: [[vc2:%.*]] = arith.constant 2 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ %1:3 = shard.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
// CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir
similarity index 50%
rename from mlir/test/Dialect/Arith/mesh-spmdize.mlir
rename to mlir/test/Dialect/Arith/shard-partition.mlir
index 6b55dd533a92c..5f3bca741e9d5 100644
--- a/mlir/test/Dialect/Arith/mesh-spmdize.mlir
+++ b/mlir/test/Dialect/Arith/shard-partition.mlir
@@ -1,17 +1,17 @@
// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization))" \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition))" \
// RUN: %s | FileCheck %s
-mesh.mesh @mesh4x4(shape = 4x4)
+shard.grid @grid4x4(shape = 4x4)
-// CHECK-LABEL: func @test_spmdize_constant
+// CHECK-LABEL: func @test_partition_constant
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
// tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
// i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
+func.func @test_partition_constant() ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
%cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 434 : i32
return %sharding_annotated_1 : tensor<1024x1024xf32>
}
diff --git a/mlir/test/Dialect/Arith/sharding-propagation.mlir b/mlir/test/Dialect/Arith/sharding-propagation.mlir
index 19eb340549b0b..d62905dde7bb5 100644
--- a/mlir/test/Dialect/Arith/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Arith/sharding-propagation.mlir
@@ -1,27 +1,27 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
-mesh.mesh @mesh4x4(shape = 4x4)
+shard.grid @grid4x4(shape = 4x4)
// CHECK-LABEL: func.func @test_shard_constant() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = shard.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = shard.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
// CHECK-NEXT: return [[vsharding_annotated_8]] : tensor<1024x1024xf32>
func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%sharding_annotated_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
@@ -30,25 +30,25 @@ func.func @test_shard_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emi
// CHECK-LABEL: func.func @test_shard_constant_back() -> tensor<1024x1024xf32> attributes {llvm.emit_c_interface} {
// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 4.340000e+01 : f32
// CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_1:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = mesh.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_3:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_5:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = mesh.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_1:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_2:%.*]] = shard.shard [[v0]] to [[vsharding_1]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_3:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_4:%.*]] = shard.shard [[vsharding_annotated]] to [[vsharding_3]] annotate_for_users : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_5:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_6:%.*]] = shard.shard [[vsharding_annotated_2]] to [[vsharding_5]] annotate_for_users : tensor<1024x1024xf32>
// CHECK-NEXT: [[v1:%.*]] = linalg.add ins([[vsharding_annotated_4]], [[vcst_0]] : tensor<1024x1024xf32>, f32) outs([[vsharding_annotated_6]] : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-// CHECK-NEXT: [[vsharding_7:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
-// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = mesh.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
+// CHECK-NEXT: [[vsharding_7:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+// CHECK-NEXT: [[vsharding_annotated_8:%.*]] = shard.shard [[v1]] to [[vsharding_7]] : tensor<1024x1024xf32>
func.func @test_shard_constant_back() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
%cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
%ci = arith.constant 43.4e+00 : f32
%o1 = tensor.empty() : tensor<1024x1024xf32>
%res = linalg.add ins(%cst_1, %ci : tensor<1024x1024xf32>, f32) outs(%o1 : tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %res to %sharding_1 : tensor<1024x1024xf32>
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %res to %sharding_1 : tensor<1024x1024xf32>
return %sharding_annotated_1 : tensor<1024x1024xf32>
}
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
deleted file mode 100644
index 5297eeb666c1e..0000000000000
--- a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --verify-each \
-// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh_2(shape = 2)
-
-// CHECK-LABEL: func @matmul_shard_prallel_axis
-func.func @matmul_shard_prallel_axis(
- // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
- %arg0 : tensor<2x3xf32>,
- // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
- %arg1 : tensor<3x2xf32>,
- // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
- %out_dps: tensor<2x2xf32>
-) -> tensor<2x2xf32> {
- // CHECK: %[[SIN1_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
- // CHECK: %[[SIN1_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
- // CHECK: %[[SIN2_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
- // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
- // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
- %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
- %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
-
- // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
- // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
- %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
- outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
-
- // CHECK: %[[SRES_ANNOTATED_0:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[0]] : !mesh.sharding
- // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
- // CHECK: %[[SRES_ANNOTATED_1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[}}[]] : !mesh.sharding
- // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
- %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
- %res_sharded = mesh.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
-
- // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
- return %res_sharded : tensor<2x2xf32>
-}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/shard-partition.mlir
similarity index 50%
rename from mlir/test/Dialect/Linalg/mesh-spmdization.mlir
rename to mlir/test/Dialect/Linalg/shard-partition.mlir
index ce12b296df1fa..aee97079fb197 100644
--- a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Linalg/shard-partition.mlir
@@ -1,15 +1,15 @@
// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
// RUN: --split-input-file \
// RUN: %s | FileCheck %s
// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
#map_identity_1d = affine_map<(d0) -> (d0)>
-mesh.mesh @mesh_1d(shape = 2)
+shard.grid @grid_1d(shape = 2)
-// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
-func.func @elementwise_static_1d_mesh_static_1d_tensor(
+// CHECK-LABEL: func @elementwise_static_1d_grid_static_1d_tensor
+func.func @elementwise_static_1d_grid_static_1d_tensor(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
%in1: tensor<2xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
@@ -18,13 +18,13 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
%dps_out: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in1_sharded1 = mesh.shard %in1 to %sharding : tensor<2xi8>
- %in1_sharded2 = mesh.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
- %in2_sharded1 = mesh.shard %in2 to %sharding : tensor<2xi8>
- %in2_sharded2 = mesh.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
- %dps_out_sharded1 = mesh.shard %dps_out to %sharding : tensor<2xi8>
- %dps_out_shared2 = mesh.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %in1_sharded1 = shard.shard %in1 to %sharding : tensor<2xi8>
+ %in1_sharded2 = shard.shard %in1_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %in2_sharded1 = shard.shard %in2 to %sharding : tensor<2xi8>
+ %in2_sharded2 = shard.shard %in2_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %dps_out_sharded1 = shard.shard %dps_out to %sharding : tensor<2xi8>
+ %dps_out_shared2 = shard.shard %dps_out_sharded1 to %sharding annotate_for_users : tensor<2xi8>
// CHECK: %[[RES:.*]] = linalg.generic {
// CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
// CHECK-SAME: iterator_types = ["parallel"]}
@@ -39,18 +39,18 @@ func.func @elementwise_static_1d_mesh_static_1d_tensor(
%res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
linalg.yield %res_scalar : i8
} -> tensor<2xi8>
- %res_sharded1 = mesh.shard %res to %sharding : tensor<2xi8>
- %res_shared2 = mesh.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
+ %res_sharded1 = shard.shard %res to %sharding : tensor<2xi8>
+ %res_shared2 = shard.shard %res_sharded1 to %sharding annotate_for_users : tensor<2xi8>
// CHECK: return %[[RES]] : tensor<1xi8>
return %res_shared2 : tensor<2xi8>
}
// -----
-mesh.mesh @mesh_1d(shape = 4)
+shard.grid @grid_1d(shape = 4)
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
-func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_sharding
+func.func @matmul_1d_grid_static_tensors_parallel_iterator_sharding(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
%in1: tensor<4x3xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
@@ -59,32 +59,32 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<1x8xi8> {
) -> tensor<4x8xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x3xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<3x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
- %dps_out_shared1 = mesh.shard %dps_out to %sharding : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
+ %sharding = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x3xi8>
+ %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x3xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<3x8xi8>
+ %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<3x8xi8>
+ %dps_out_shared1 = shard.shard %dps_out to %sharding : tensor<4x8xi8>
+ %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
// CHECK: %[[RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
// CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
// CHECK-SAME: -> tensor<1x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to %sharding : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
+ %res_shared1 = shard.shard %res to %sharding : tensor<4x8xi8>
+ %res_shared2 = shard.shard %res_shared1 to %sharding annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[RES]] : tensor<1x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
// -----
-mesh.mesh @mesh_1d(shape = 3)
+shard.grid @grid_1d(shape = 3)
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
-func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_reduction_iterator_sharding
+func.func @matmul_1d_grid_static_tensors_reduction_iterator_sharding(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
%in1: tensor<4x6xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
@@ -93,19 +93,19 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %sharding = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %in1_shared1 = mesh.shard %in1 to %sharding : tensor<4x6xi8>
- %in1_shared2 = mesh.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %in2_shared1 = mesh.shard %in2 to %sharding2 : tensor<6x8xi8>
- %in2_shared2 = mesh.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
- %sharding3 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %dps_out_shared1 = mesh.shard %dps_out to %sharding3 : tensor<4x8xi8>
- %dps_out_shared2 = mesh.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
+ %sharding = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %in1_shared1 = shard.shard %in1 to %sharding : tensor<4x6xi8>
+ %in1_shared2 = shard.shard %in1_shared1 to %sharding annotate_for_users : tensor<4x6xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %in2_shared1 = shard.shard %in2 to %sharding2 : tensor<6x8xi8>
+ %in2_shared2 = shard.shard %in2_shared1 to %sharding2 annotate_for_users : tensor<6x8xi8>
+ %sharding3 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %dps_out_shared1 = shard.shard %dps_out to %sharding3 : tensor<4x8xi8>
+ %dps_out_shared2 = shard.shard %dps_out_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
- // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[PROCESS_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
+ // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index
// CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
// CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
// CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
@@ -117,21 +117,21 @@ func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
// CHECK: }
// CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
// CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
- // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
+ // CHECK: %[[ALL_REDUCED:.*]] = shard.all_reduce %[[SHARDED_MATMUL]] on @grid_1d grid_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
%res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
- %res_shared1 = mesh.shard %res to %sharding3 : tensor<4x8xi8>
- %res_shared2 = mesh.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
+ %res_shared1 = shard.shard %res to %sharding3 : tensor<4x8xi8>
+ %res_shared2 = shard.shard %res_shared1 to %sharding3 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8>
return %res_shared2 : tensor<4x8xi8>
}
// -----
-mesh.mesh @mesh_1d(shape = 4)
+shard.grid @grid_1d(shape = 4)
-// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis
-func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
+// CHECK-LABEL: func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis
+func.func @matmul_1d_grid_static_tensors_parallel_iterator_unsplit_last_axis(
// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x6xi8>,
%in1: tensor<4x6xi8>,
// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<6x8xi8>,
@@ -140,25 +140,25 @@ func.func @matmul_1d_mesh_static_tensors_parallel_iterator_unsplit_last_axis(
%dps_out: tensor<4x8xi8>
// CHECK-SAME: -> tensor<4x8xi8> {
) -> tensor<4x8xi8> {
- %sharding1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
- %in1_replicated1 = mesh.shard %in1 to %sharding1 : tensor<4x6xi8>
- %in1_replicated2 = mesh.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
- // CHECK: %[[ALL_SLICE1:.*]] = mesh.all_slice %[[IN2]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- %in2_replicated = mesh.shard %in2 to %sharding1 : tensor<6x8xi8>
- %sharding2 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %in2_sharded = mesh.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
- // CHECK: %[[ALL_SLICE2:.*]] = mesh.all_slice %[[DPS_OUT]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- %dps_out_replicated = mesh.shard %dps_out to %sharding1 : tensor<4x8xi8>
- %dps_out_sharded = mesh.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
+ %sharding1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+ %in1_replicated1 = shard.shard %in1 to %sharding1 : tensor<4x6xi8>
+ %in1_replicated2 = shard.shard %in1_replicated1 to %sharding1 annotate_for_users : tensor<4x6xi8>
+ // CHECK: %[[ALL_SLICE1:.*]] = shard.all_slice %[[IN2]] on @grid_1d grid_axes = [0] slice_axis = 1
+ %in2_replicated = shard.shard %in2 to %sharding1 : tensor<6x8xi8>
+ %sharding2 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %in2_sharded = shard.shard %in2_replicated to %sharding2 annotate_for_users : tensor<6x8xi8>
+ // CHECK: %[[ALL_SLICE2:.*]] = shard.all_slice %[[DPS_OUT]] on @grid_1d grid_axes = [0] slice_axis = 1
+ %dps_out_replicated = shard.shard %dps_out to %sharding1 : tensor<4x8xi8>
+ %dps_out_sharded = shard.shard %dps_out_replicated to %sharding2 annotate_for_users : tensor<4x8xi8>
// CHECK: %[[MATMUL_RES:.*]] = linalg.matmul
// CHECK-SAME: ins(%[[IN1]], %[[ALL_SLICE1]] : tensor<4x6xi8>, tensor<6x2xi8>)
// CHECK-SAME: outs(%[[ALL_SLICE2]] : tensor<4x2xi8>)
// CHECK-SAME: -> tensor<4x2xi8>
%res = linalg.matmul ins(%in1_replicated2, %in2_sharded : tensor<4x6xi8>, tensor<6x8xi8>)
outs(%dps_out_sharded : tensor<4x8xi8>) -> tensor<4x8xi8>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[MATMUL_RES]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
- %res_sharded = mesh.shard %res to %sharding2 : tensor<4x8xi8>
- %res_replicated = mesh.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[MATMUL_RES]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<4x2xi8> -> tensor<4x8xi8>
+ %res_sharded = shard.shard %res to %sharding2 : tensor<4x8xi8>
+ %res_replicated = shard.shard %res_sharded to %sharding1 annotate_for_users : tensor<4x8xi8>
// CHECK: return %[[ALL_GATHER]] : tensor<4x8xi8>
return %res_replicated : tensor<4x8xi8>
}
diff --git a/mlir/test/Dialect/Linalg/sharding-propagation.mlir b/mlir/test/Dialect/Linalg/sharding-propagation.mlir
new file mode 100644
index 0000000000000..e0ecefcf2d6bd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sharding-propagation.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt \
+// RUN: --verify-each \
+// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_parallel_axis
+func.func @matmul_shard_parallel_axis(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+ %arg0 : tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+ %arg1 : tensor<3x2xf32>,
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+ // CHECK: %[[SIN1_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_ANNOTATED_0:.*]] = shard.shard %[[IN1]] to %[[SIN1_ANNOTATED_0]] : tensor<2x3xf32>
+ // CHECK: %[[SIN1_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_ANNOTATED_1:.*]] = shard.shard %[[IN1_ANNOTATED_0]] to %[[SIN1_ANNOTATED_1]] annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[SIN2_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding
+ // CHECK-NEXT: %[[IN2_ANNOTATED:.*]] = shard.shard %[[IN2]] to %[[SIN2_ANNOTATED]] annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[SDPS_OUT_ANNOTATED:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[DPS_OUT_ANNOTATED:.*]] = shard.shard %[[DPS_OUT]] to %[[SDPS_OUT_ANNOTATED]] annotate_for_users : tensor<2x2xf32>
+ %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding
+ %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
+
+ // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+ // CHECK: %[[SRES_ANNOTATED_0:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[0]] : !shard.sharding
+ // CHECK-NEXT: %[[RES_ANNOTATED_0:.*]] = shard.shard %[[RES]] to %[[SRES_ANNOTATED_0]] : tensor<2x2xf32>
+ // CHECK: %[[SRES_ANNOTATED_1:.*]] = shard.sharding @grid_2 split_axes = {{\[}}[]] : !shard.sharding
+ // CHECK-NEXT: %[[RES_ANNOTATED_1:.*]] = shard.shard %[[RES_ANNOTATED_0]] to %[[SRES_ANNOTATED_1]] annotate_for_users : tensor<2x2xf32>
+ %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding
+ %res_sharded = shard.shard %res to %sres_sharded annotate_for_users : tensor<2x2xf32>
+
+ // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
deleted file mode 100644
index aff07bbf8a214..0000000000000
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ /dev/null
@@ -1,248 +0,0 @@
-// RUN: mlir-opt --canonicalize %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 2x4)
-
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes
-func.func @all_reduce_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_reduce
- %0 = mesh.all_reduce %arg0 on @mesh0
- mesh_axes = []
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
-func.func @all_reduce_empty_mesh_axes_different_return_type(
- %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh.all_reduce
- %0 = mesh.all_reduce %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
- mesh_axes = []
- : tensor<4xf32> -> tensor<4xf64>
- return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @all_reduce_default_reduction
-func.func @all_reduce_default_reduction(
- %arg0 : tensor<4xf32>) -> tensor<4xf64> {
- %0 = mesh.all_reduce %arg0 on @mesh0
- mesh_axes = [0]
-// CHECK-NOT: reduction
- reduction = sum
- : tensor<4xf32> -> tensor<4xf64>
- return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @all_to_all_empty_mesh_axes
-func.func @all_to_all_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
- %arg0 : tensor<8xf32>) -> tensor<8xf32> {
-// CHECK-NOT: mesh.all_to_all
- %0 = mesh.all_to_all %arg0 on @mesh0
- mesh_axes = []
- split_axis = 0
- concat_axis = 0
- : tensor<8xf32> -> tensor<8xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<8xf32>
-}
-
-// CHECK-LABEL: func @all_gather_empty_mesh_axes
-func.func @all_gather_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.all_gather
- %0 = mesh.all_gather %arg0 on @mesh0
- mesh_axes = []
- gather_axis = 0
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @all_slice_empty_mesh_axes
-func.func @all_slice_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.scatter
- %0 = mesh.all_slice %arg0 on @mesh0
- mesh_axes = []
- slice_axis = 0
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @broadcast_empty_mesh_axes
-func.func @broadcast_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.broadcast
- %0 = mesh.broadcast %arg0 on @mesh0
- mesh_axes = []
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @gather_empty_mesh_axes
-func.func @gather_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.gather
- %0 = mesh.gather %arg0 on @mesh0
- mesh_axes = []
- gather_axis = 0
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @receive_empty_mesh_axes
-func.func @receive_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.recv
- %0 = mesh.recv %arg0 on @mesh0
- mesh_axes = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_empty_mesh_axes
-func.func @reduce_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.reduce
- %0 = mesh.reduce %arg0 on @mesh0
- mesh_axes = []
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
-func.func @reduce_scatter_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.reduce_scatter
- %0 = mesh.reduce_scatter %arg0 on @mesh0
- mesh_axes = []
- scatter_axis = 0
- : tensor<4xf32> -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
-func.func @reduce_scatter_empty_mesh_axes_different_return_type(
- %arg0 : tensor<4xf32>) -> tensor<4xf64> {
-// CHECK: mesh.reduce_scatter
- %0 = mesh.reduce_scatter %arg0 on @mesh0
-// CHECK-NOT: mesh_axes
- mesh_axes = []
- scatter_axis = 0
- : tensor<4xf32> -> tensor<4xf64>
- return %0 : tensor<4xf64>
-}
-
-// CHECK-LABEL: func @reduce_scatter_default_reduction
-func.func @reduce_scatter_default_reduction(
- %arg0 : tensor<4xf32>) -> tensor<2xf64> {
- %0 = mesh.reduce_scatter %arg0 on @mesh0
- mesh_axes = [0]
-// CHECK-NOT: reduction
- reduction = sum
- scatter_axis = 0
- : tensor<4xf32> -> tensor<2xf64>
- return %0 : tensor<2xf64>
-}
-
-// CHECK-LABEL: func @scatter_empty_mesh_axes
-func.func @scatter_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.scatter
- %0 = mesh.scatter %arg0 on @mesh0
- mesh_axes = []
- scatter_axis = 0
- root = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-// CHECK-LABEL: func @send_empty_mesh_axes
-func.func @send_empty_mesh_axes(
-// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
- %arg0 : tensor<4xf32>) -> tensor<4xf32> {
-// CHECK-NOT: mesh.send
- %0 = mesh.send %arg0 on @mesh0
- mesh_axes = []
- destination = []
- : (tensor<4xf32>) -> tensor<4xf32>
-// CHECK: return %[[ARG]]
- return %0 : tensor<4xf32>
-}
-
-mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_halo_sizes
-func.func @test_halo_sizes() -> !mesh.sharding {
- %c2_i64 = arith.constant 2 : i64
- // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !mesh.sharding
- %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !mesh.sharding
- return %sharding : !mesh.sharding
-}
-
-// CHECK-LABEL: func @test_shard_offs
-func.func @test_shard_offs() -> !mesh.sharding {
- %c2_i64 = arith.constant 2 : i64
- // CHECK mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !mesh.sharding
- %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !mesh.sharding
- return %sharding : !mesh.sharding
-}
-
-// CHECK-LABEL: func @test_duplicate_shardops
-func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
- // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
- %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
- %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
- // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
- return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-}
-
-// CHECK-LABEL: func @test_duplicate_shardops_diff
-func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
- // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0]] : !mesh.sharding
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding_0:%.*]] = mesh.sharding @mesh4x4 split_axes = {{\[\[}}0, 1]] : !mesh.sharding
- %sharding_2 = mesh.sharding @mesh4x4 split_axes = [[0, 1]] : !mesh.sharding
- // CHECK-NEXT: [[vsharding_annotated:%.*]] = mesh.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
- %sharding_annotated_2 = mesh.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
- %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_3 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_3 = mesh.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
- // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = mesh.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
- %sharding_annotated_1 = mesh.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
- // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
- return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/folding.mlir b/mlir/test/Dialect/Mesh/folding.mlir
deleted file mode 100644
index 369f316d0f797..0000000000000
--- a/mlir/test/Dialect/Mesh/folding.mlir
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 4x?x2)
-mesh.mesh @mesh1(shape = 2x3)
-
-// CHECK-LABEL: func.func @mesh_shape_op_folding
-func.func @mesh_shape_op_folding() -> (index, index) {
- // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
- // CHECK: %[[AXIS_1_SIZE:.*]] = mesh.mesh_shape @mesh0 axes = [1] : index
- %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
- // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
- return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @mesh_shape_op_folding_all_axes_static_mesh
-func.func @mesh_shape_op_folding_all_axes_static_mesh() -> (index, index) {
- // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
- // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
- %0:2 = mesh.mesh_shape @mesh1 : index, index
- // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
- return %0#0, %0#1 : index, index
-}
diff --git a/mlir/test/Dialect/Mesh/inlining.mlir b/mlir/test/Dialect/Mesh/inlining.mlir
deleted file mode 100644
index c41a709e1a4eb..0000000000000
--- a/mlir/test/Dialect/Mesh/inlining.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-opt -inline %s | FileCheck %s
-
-mesh.mesh @mesh0(shape = 4x?x2)
-
-func.func private @mesh_to_inline() -> (index, index) {
- %0:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index, index
- return %0#0, %0#1 : index, index
-}
-// CHECK-LABEL: func.func @main
-func.func @main() -> (index, index) {
- // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = mesh.mesh_shape @mesh0 axes = [2, 1] : index
- %0:2 = func.call @mesh_to_inline() : () -> (index, index)
- // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1
- return %0#0, %0#1 : index, index
-}
diff --git a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
deleted file mode 100644
index e23cfd79a4274..0000000000000
--- a/mlir/test/Dialect/Mesh/process-multi-index-op-lowering.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt -test-mesh-process-multi-index-op-lowering %s | FileCheck %s
-
-mesh.mesh @mesh2d(shape = ?x?)
-
-// CHECK-LABEL: func.func @multi_index_2d_mesh
-func.func @multi_index_2d_mesh() -> (index, index) {
- // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
- // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
- %0:2 = mesh.process_multi_index on @mesh2d : index, index
- // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
- return %0#0, %0#1 : index, index
-}
-
-// CHECK-LABEL: func.func @multi_index_2d_mesh_single_inner_axis
-func.func @multi_index_2d_mesh_single_inner_axis() -> index {
- // CHECK: %[[LINEAR_IDX:.*]] = mesh.process_linear_index on @mesh2d : index
- // CHECK: %[[MESH_SHAPE:.*]]:2 = mesh.mesh_shape @mesh2d : index, index
- // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[MESH_SHAPE]]#0, %[[MESH_SHAPE]]#1) : index, index
- %0 = mesh.process_multi_index on @mesh2d axes = [0] : index
- // CHECK: return %[[MULTI_IDX]]#0 : index
- return %0 : index
-}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
deleted file mode 100644
index 5e62c929aa4ff..0000000000000
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ /dev/null
@@ -1,168 +0,0 @@
-// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
-
-mesh.mesh @mesh_1d(shape = 2)
-mesh.mesh @mesh_1d_dynamic(shape = ?)
-
-// CHECK-LABEL: func @same_source_and_target_sharding
-func.func @same_source_and_target_sharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
- %arg0: tensor<2xf32>
-) -> tensor<2xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xf32>
- // CHECK: return %[[ARG]]
- return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @identical_source_and_target_sharding
-func.func @identical_source_and_target_sharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
- %arg0: tensor<2xf32>
-) -> tensor<2xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xf32>
- %1 = mesh.shard %0 to %s0 annotate_for_users : tensor<2xf32>
- // CHECK: return %[[ARG]]
- return %1 : tensor<2xf32>
-}
-
-// CHECK-LABEL: func @split_replicated_tensor_axis
-func.func @split_replicated_tensor_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
- %arg0: tensor<3x14xf32>
-) -> tensor<3x14xf32> {
- // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
- // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
- // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<3x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
- // CHECK: return %[[RESULT]] : tensor<3x14xf32>
- return %1 : tensor<3x14xf32>
-}
-
-// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
-func.func @split_replicated_tensor_axis_dynamic(
- // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
- %arg0: tensor<?x3x?xf32>
-) -> tensor<?x3x?xf32> {
- // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
- // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
- %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [], []] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<?x3x?xf32>
- %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
- // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
- return %1 : tensor<?x3x?xf32>
-}
-
-// CHECK-LABEL: func @move_split_axis
-func.func @move_split_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
- // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
- // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[RES]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @move_split_axis_dynamic_mesh
-func.func @move_split_axis_dynamic_mesh(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
- // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
- // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
- // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[RES]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @move_split_dynamic_axis
-func.func @move_split_dynamic_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
- %arg0: tensor<?x14xf32>
-) -> tensor<?x14xf32> {
- // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
- // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
- // CHECK: return %[[RES]] : tensor<?x14xf32>
- return %1 : tensor<?x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_axis
-func.func @unshard_static_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_last_axis
-func.func @unshard_static_last_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], []] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_dynamic_axis
-func.func @unshard_dynamic_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
- %arg0: tensor<?x14xf32>
-) -> tensor<?x14xf32> {
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<?x14xf32>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
- return %1 : tensor<?x14xf32>
-}
-
-// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
-func.func @unshard_static_axis_on_dynamic_mesh_axis(
-// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
- %arg0: tensor<10x14xf32>
-) -> tensor<10x14xf32> {
- // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
- // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
- %s0 = mesh.sharding @mesh_1d_dynamic split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<10x14xf32>
- %s1 = mesh.sharding @mesh_1d_dynamic split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
- // CHECK: return %[[RES]] : tensor<10x14xf32>
- return %1 : tensor<10x14xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
deleted file mode 100644
index 0881d994d60e7..0000000000000
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ /dev/null
@@ -1,301 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
-
-mesh.mesh @mesh_2(shape = 2)
-mesh.mesh @mesh_1d(shape = ?)
-mesh.mesh @mesh_2d(shape = 2x4)
-mesh.mesh @mesh_3d(shape = ?x?x?)
-
-// CHECK-LABEL: func.func @element_wise_empty_sharding_info
-func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: tosa.sigmoid
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: return
- return %0 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_def
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V2]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_use
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V2]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_output
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @element_wise_on_graph_input
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @arrow_structure
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
- %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
- // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
- %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP1:.*]] = mesh.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[ZP2:.*]] = mesh.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
- // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to %[[S1]] : tensor<8x16xf32>
- %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
- %s3 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V6]], %[[V8]]
- return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
- %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
-func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [], [1]] : !mesh.sharding
- // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_3:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
- // CHECK: [[vsharded_5:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 : tensor<2x16x32xf32>
- // CHECK-NEXT: return [[vsharded_5]]
- return %1 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
-func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0], [1]] : !mesh.sharding
- %s0 = mesh.sharding @mesh_2d split_axes = [[], [0], [1]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
- %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x16x8xf32>
- // CHECK: [[vsharded_0:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
- // CHECK: [[vsharding_1:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
- // CHECK: [[vsharded_2:%.*]] = mesh.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharding_3:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_4:%.*]] = mesh.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- // CHECK: [[vsharding_5:%.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
- %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK: return [[vsharded_6]]
- return %0 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
-func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[], [1], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
- // CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [0]] : !mesh.sharding
- // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- %s1 = mesh.sharding @mesh_2d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
- %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
- // CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}], [1]] : !mesh.sharding
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
- // CHECK-NEXT: return %[[V3]]
- return %2 : tensor<2x16x32xf32>
-}
-
-// CHECK-LABEL: func.func @resolve_conflicting_annotations
-func.func @resolve_conflicting_annotations(
- // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
- %arg0: tensor<2x3xf32>,
- // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
- %arg1: tensor<3x2xf32>,
- // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
- %out_dps: tensor<2x2xf32>
-// CHECK-SAME: ) -> tensor<2x2xf32> {
-) -> tensor<2x2xf32> {
- // CHECK: %[[SIN1_SHARDED1:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
- // CHECK: %[[SIN2_SHARDED:.*]] = mesh.sharding @mesh_2 split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
- // CHECK-NEXT: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
- // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
- %sarg0_sharded = mesh.sharding @mesh_2 split_axes = [[0]] : !mesh.sharding
- %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
- // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
- // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
- %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
- outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK-NEXT: %[[RES:.*]] = mesh.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
- %sres_sharded = mesh.sharding @mesh_2 split_axes = [[]] : !mesh.sharding
- %res_sharded = mesh.shard %res to %sres_sharded : tensor<2x2xf32>
- // CHECK: return %[[RES]] : tensor<2x2xf32>
- return %res_sharded : tensor<2x2xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(a)
-// The sharding propagation results in unnecessary reshards,
-// an optimization pass should be able to remove them.
-// CHECK-LABEL: func.func @mlp_1d_weight_stationary
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
-func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- %sharded0 = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
- %sharded1 = mesh.shard %arg1 to %s0 : tensor<2x8x32xf32>
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
- // CHECK: [[vsharded_0:%.*]] = mesh.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
- // CHECK: [[vsharded_1:%.*]] = mesh.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
- // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
- %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- %sharding = mesh.sharding @mesh_1d split_axes = [[], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded_9:%.*]] = mesh.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
- %sharded2 = mesh.shard %arg2 to %sharding : tensor<2x32x8xf32>
- // CHECK: [[vsharded_10:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
- // CHECK: [[v2:%.*]] = tosa.matmul
- %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
- // CHECK: [[vsharded_12:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
- %s4 = mesh.sharding @mesh_1d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 : tensor<2x4x8xf32>
- // CHECK: return [[vsharded_12]]
- return %4 : tensor<2x4x8xf32>
-}
-
-// https://arxiv.org/abs/2211.05102 Figure 2(b)
-// The sharding propagation results in unnecessary reshards,
-// an optimization pass should be able to remove them.
-// CHECK-LABEL: func.func @mlp_2d_weight_stationary
-// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
-func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
- // CHECK: [[vsharding:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !mesh.sharding
- %s0 = mesh.sharding @mesh_3d split_axes = [[], [], [0, 1, 2]] : !mesh.sharding
- // CHECK: [[vsharded:%.*]] = mesh.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
- %arg0_s = mesh.shard %arg0 to %s0 : tensor<2x4x8xf32>
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [0], [1, 2]] : !mesh.sharding
- %s1 = mesh.sharding @mesh_3d split_axes = [[], [0], [1, 2]] : !mesh.sharding
- // CHECK: [[vsharded_1:%.*]] = mesh.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
- %arg1_s = mesh.shard %arg1 to %s1 : tensor<2x8x32xf32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK: [[vsharded_3:%.*]] = mesh.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
- // CHECK: [[vsharded_4:%.*]] = mesh.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK: [[vsharded_5:%.*]] = mesh.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
- // CHECK: [[v0:%.*]] = tosa.matmul
- %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharded_6:%.*]] = mesh.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
- %2 = mesh.shard %1 to %s0 : tensor<2x4x32xf32>
- // CHECK: [[vsharded_7:%.*]] = mesh.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[v1:%.*]] = tosa.sigmoid
- // CHECK: [[vsharded_8:%.*]] = mesh.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
- %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
- // CHECK: [[vsharding_9:%.*]] = mesh.sharding @mesh_3d split_axes = {{\[\[}}], [1, 2], [0]] : !mesh.sharding
- %s2 = mesh.sharding @mesh_3d split_axes = [[], [1, 2], [0]] : !mesh.sharding
- // CHECK: [[vsharded_10:%.*]] = mesh.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
- %arg2_s = mesh.shard %arg2 to %s2 : tensor<2x32x8xf32>
- // CHECK: [[vsharded_11:%.*]] = mesh.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
- // CHECK: [[vsharded_12:%.*]] = mesh.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
- // CHECK: [[v2:%.*]] = tosa.matmul
- %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
- // CHECK: [[vsharded_13:%.*]] = mesh.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
- %5 = mesh.shard %4 to %s0 : tensor<2x4x8xf32>
- // CHECK: [[vsharded_14:%.*]] = mesh.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
- %6 = mesh.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32>
- // CHECK: return [[vsharded_14]]
- return %6 : tensor<2x4x8xf32>
-}
-
-// CHECK-LABEL: func.func @elementwise_duplicated_chain
-// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
-func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
- // CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
- // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
- %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
- // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V5:.*]] = mesh.shard %[[V4]] to %[[S0]] : tensor<8x16xf32>
- %s0 = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
- %2 = mesh.shard %1 to %s0 : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V5]]
- return %2 : tensor<8x16xf32>
-}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
deleted file mode 100644
index 701898cbdc74d..0000000000000
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ /dev/null
@@ -1,317 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh_1d(shape = 2)
-
-// CHECK-LABEL: func @return_sharding
-func.func @return_sharding(
- // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
- %arg0: tensor<2xf32>
-// CHECK-SAME: ) -> (tensor<1xf32>, !mesh.sharding) {
-) -> (tensor<2xf32>, !mesh.sharding) {
- %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
- // CHECK-NEXT: [[vsharding:%.*]] = mesh.sharding @mesh_1d split_axes = {{\[\[}}0]] : !mesh.sharding
- %r = mesh.get_sharding %sharding_annotated : tensor<2xf32> -> !mesh.sharding
- // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !mesh.sharding
- return %sharding_annotated, %r : tensor<2xf32>, !mesh.sharding
-}
-
-// CHECK-LABEL: func @full_replication
-func.func @full_replication(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<2xi8> {
-) -> tensor<2xi8> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[ARG]] : tensor<2xi8>
- return %1 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @sharding_triplet
-func.func @sharding_triplet(
- // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
- %arg0: tensor<2xf32>
-// CHECK-SAME: ) -> tensor<2xf32> {
-) -> tensor<2xf32> {
- // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
- %ssharding_annotated = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated : tensor<2xf32>
- %ssharding_annotated_0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_0 = mesh.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32>
- %ssharding_annotated_1 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32>
- // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
- return %sharding_annotated_1 : tensor<2xf32>
-}
-
-
-// CHECK-LABEL: func @move_split_axis
-func.func @move_split_axis(
- // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
- %arg0: tensor<2x2xi8>
-// CHECK-SAME: -> tensor<2x1xi8> {
-) -> tensor<2x2xi8> {
- // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d
- // CHECK-SAME: mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2x2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[], [0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2x2xi8>
- // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
- return %1 : tensor<2x2xi8>
-}
-
-// CHECK-LABEL: func @non_tensor_value
-func.func @non_tensor_value(
- // CHECK-SAME: %[[ARG:.*]]: i8
- %arg0: i8
-// CHECK-SAME: -> i8 {
-) -> i8 {
- // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
- %0 = arith.addi %arg0, %arg0 : i8
- // CHECK: return %[[RES]] : i8
- return %0 : i8
-}
-
-// CHECK-LABEL: func @unary_elementwise
-func.func @unary_elementwise(
- // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
- %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<2xi8>
- %s4 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RES]] : tensor<1xi8>
- return %4 : tensor<2xi8>
-}
-
-// full replication -> shard axis -> abs -> shard axis -> full replication
-// CHECK-LABEL: func @unary_elementwise_with_resharding
-func.func @unary_elementwise_with_resharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<2xi8> {
-) -> tensor<2xi8> {
- // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
- // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
- %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RES:.*]] = mesh.all_gather %[[ABS]] on @mesh_1d
- // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
- %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<2xi8>
- %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RES]] : tensor<2xi8>
- return %4 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @binary_elementwise
-func.func @binary_elementwise(
- // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
- %arg0: tensor<2xi8>,
- // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
- %arg1: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
- %sarg0_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %arg0_sharded = mesh.shard %arg0 to %sarg0_sharded : tensor<2xi8>
- %sop_arg0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %op_arg0 = mesh.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8>
- %sarg1_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %arg1_sharded = mesh.shard %arg1 to %sarg1_sharded : tensor<2xi8>
- %sop_arg1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %op_arg1 = mesh.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
- %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
- %sop_res_sharded = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %op_res_sharded = mesh.shard %op_res to %sop_res_sharded : tensor<2xi8>
- %sres = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %res = mesh.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RES]] : tensor<1xi8>
- return %res : tensor<2xi8>
-}
-
-// reshard
-// abs
-// reshard
-// abs
-// reshard
-// CHECK-LABEL: func @multiple_chained_ops
-func.func @multiple_chained_ops(
- // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
- %arg0: tensor<2xi8>
-// CHECK-SAME: -> tensor<1xi8> {
-) -> tensor<2xi8> {
- // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
- // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %s0 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<2xi8>
- %s1 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<2xi8>
- // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
- %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD2:.*]] = mesh.all_gather %[[ABS1]] on @mesh_1d
- // CHECK-SAME: mesh_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
- %s3 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %3 = mesh.shard %2 to %s3 : tensor<2xi8>
- %s4 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %4 = mesh.shard %3 to %s4 annotate_for_users : tensor<2xi8>
- // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
- %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
- // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
- %s6 = mesh.sharding @mesh_1d split_axes = [[]] : !mesh.sharding
- %6 = mesh.shard %5 to %s6 : tensor<2xi8>
- %s7 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %7 = mesh.shard %6 to %s7 annotate_for_users : tensor<2xi8>
- // CHECK: return %[[RESHARD3]] : tensor<1xi8>
- return %7 : tensor<2xi8>
-}
-
-// CHECK-LABEL: func @incomplete_sharding
-func.func @incomplete_sharding(
- // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
- %arg0: tensor<8x16xf32>
-// CHECK-SAME: -> tensor<4x16xf32> {
-) -> tensor<8x16xf32> {
- %s0 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
- // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
- %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %s2 = mesh.sharding @mesh_1d split_axes = [[0]] : !mesh.sharding
- %2 = mesh.shard %1 to %s2 : tensor<8x16xf32>
- // CHECK: return %[[RES]] : tensor<4x16xf32>
- return %2 : tensor<8x16xf32>
-}
-
-mesh.mesh @mesh_1d_4(shape = 4)
-
-// CHECK-LABEL: func @ew_chain_with_halo
-func.func @ew_chain_with_halo(
- // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
- %arg0: tensor<8x16xf32>,
- // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
- %arg1: tensor<1xf32>,
- // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
- %arg2: tensor<1xf32>)
- // CHECK-SAME: -> tensor<5x16xf32>
- -> tensor<8x16xf32> {
- %ssharding_annotated = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated = mesh.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32>
- // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
- %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %ssharding_annotated_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_0 = mesh.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32>
- %ssharding_annotated_1 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
- %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- %ssharding_annotated_2 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_2 = mesh.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32>
- %ssharding_annotated_4 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_4 = mesh.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
- %sharding_1 = mesh.sharding @mesh_1d_4 split_axes = [[]] : !mesh.sharding
- %zero_point_1 = mesh.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
- %zero_point_2 = mesh.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
- %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
- %ssharding_annotated_5 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_5 = mesh.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32>
- %ssharding_annotated_6 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !mesh.sharding
- %sharding_annotated_6 = mesh.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
- return %sharding_annotated_6 : tensor<8x16xf32>
-}
-
-// CHECK-LABEL: func @test_shard_update_halo
-// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
-func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
- // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
- // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
- // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
- %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
- %sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
- %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
- // CHECK: return %[[UH]] : tensor<304x1200xi64>
- return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
-
-mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_shard_update_halo2d
-// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
-func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
- %sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
- // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
- // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
- // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
- %sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
- %sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
- %sharding_annotated_3 = mesh.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
- // CHECK: return %[[UH]] : tensor<303x307xi64>
- return %sharding_annotated_3 : tensor<1200x1200xi64>
-}
-
-mesh.mesh @mesh(shape = 2)
-// CHECK-LABEL: func.func @test_reduce_0d(
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
-func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) {
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
- %4 = tensor.empty() : tensor<i32>
- %sharding_out = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
- %sharded_out = mesh.shard %4 to %sharding_out : tensor<i32>
- %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
- // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
- %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1]
- (%in: i32, %init: i32) {
- %6 = arith.addi %in, %init : i32
- linalg.yield %6 : i32
- }
- // CHECK: %[[all_reduce:.*]] = mesh.all_reduce %[[reduced]] on @mesh mesh_axes = [0] : tensor<i32> -> tensor<i32>
- %sharded_red = mesh.shard %reduced to %sharding_out : tensor<i32>
- %sharded_ret = mesh.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
- // CHECK: return %[[all_reduce]] : tensor<i32>
- return %sharded_ret : tensor<i32>
-}
-
-// CHECK-LABEL: func.func @test_reduce_1d(
-// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
-func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharded = mesh.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
- %4 = tensor.empty() : tensor<6xi32>
- %sharded_out = mesh.shard %4 to %sharding : tensor<6xi32>
- %sharded_in = mesh.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
- // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
- %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
- (%in: i32, %init: i32) {
- %6 = arith.addi %in, %init : i32
- linalg.yield %6 : i32
- }
- // CHECK-NOT: mesh.all_reduce
- %sharded_red = mesh.shard %reduced to %sharding : tensor<6xi32>
- %sharded_ret = mesh.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
- // CHECK: return %[[reduced]] : tensor<3xi32>
- return %sharded_ret : tensor<6xi32>
-}
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
similarity index 72%
rename from mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
rename to mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
index 4f54607a1c7ff..bc911215851aa 100644
--- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir
@@ -1,43 +1,43 @@
-// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-grid-all-slice-op-lowering --test-grid-simplifications --cse %s | FileCheck %s
-mesh.mesh @mesh_1d(shape = ?)
+shard.grid @grid_1d(shape = ?)
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
-func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid
+func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_grid(
// CHECK: %[[ARG:.*]]: tensor<?xf16>
%arg0: tensor<?xf16>
// CHECK-SAME: -> tensor<?xf16> {
) -> tensor<?xf16> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
+ // CHECK-DAG: %[[SHARD_SIZE:.*]] = shard.grid_shape @grid_1d axes = [0] : index
// CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %c0 : tensor<?xf16>
- // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+ // CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index
// CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMINDER]], %[[C0]] : index
// CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+ // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[SHARD_SIZE]] : index
// CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
- %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+ %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
// CHECK: return %[[RESULT]] : tensor<?xf16>
return %0 : tensor<?xf16>
}
// -----
-mesh.mesh @mesh_1d(shape = 2)
+shard.grid @grid_1d(shape = 2)
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
-func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid
+func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_grid(
// CHECK: %[[ARG:.*]]: tensor<2xf16>
%arg0: tensor<2xf16>
// CHECK-SAME: -> tensor<1xf16> {
) -> tensor<1xf16> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK: %[[PROC_IDX:.*]] = shard.process_multi_index on @grid_1d axes = [0] : index
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
- %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+ %0 = shard.all_slice %arg0 on @grid_1d grid_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
// CHECK: return %[[RESULT]] : tensor<1xf16>
return %0 : tensor<1xf16>
}
@@ -46,18 +46,18 @@ func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
-mesh.mesh @mesh_4d(shape = ?x?x?x?)
+shard.grid @grid_4d(shape = ?x?x?x?)
-// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
-func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid
+func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_grid(
// CHECK: %[[ARG:.*]]: tensor<?x?xf16>
%arg0 : tensor<?x?xf16>
// CHECK-SAME: -> tensor<?x?xf16> {
) -> tensor<?x?xf16> {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
- // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+ // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = shard.process_multi_index on @grid_4d axes = [3, 1] : index, index
+ // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = shard.grid_shape @grid_4d axes = [3, 1] : index, index
// CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
// CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
// CHECK: %[[AXIS_SIZE_CHECK_REMINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
@@ -68,7 +68,7 @@ func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
// CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
// CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
- %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+ %0 = shard.all_slice %arg0 on @grid_4d grid_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
// CHECK: return %[[RESULT]] : tensor<?x?xf16>
return %0 : tensor<?x?xf16>
}
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
similarity index 76%
rename from mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
rename to mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
index 4223d01d65111..dd3bc3a5a1c94 100644
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/backward-sharding-propagation.mlir
@@ -2,17 +2,17 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ shard.grid @grid(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> tensor<6x6xi32> {
%c1_i32 = arith.constant 1 : i32
// CHECK: tensor.empty()
%0 = tensor.empty() : tensor<6x6xi32>
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- // CHECK-COUNT-2: mesh.shard
- %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ // CHECK-COUNT-2: shard.shard
+ %sharding_annotated = shard.shard %0 to %sharding : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
// CHECK: tensor.empty()
- // CHECK-NOT: mesh.shard @
+ // CHECK-NOT: shard.shard @
%2 = tensor.empty() : tensor<6x6xi32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
diff --git a/mlir/test/Dialect/Shard/canonicalization.mlir b/mlir/test/Dialect/Shard/canonicalization.mlir
new file mode 100644
index 0000000000000..4aadb581baf7f
--- /dev/null
+++ b/mlir/test/Dialect/Shard/canonicalization.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+shard.grid @grid0(shape = 2x4)
+
+// CHECK-LABEL: func @all_reduce_empty_grid_axes
+func.func @all_reduce_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_reduce
+ %0 = shard.all_reduce %arg0 on @grid0
+ grid_axes = []
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_reduce_empty_grid_axes_different_return_type
+func.func @all_reduce_empty_grid_axes_different_return_type(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: shard.all_reduce
+ %0 = shard.all_reduce %arg0 on @grid0
+// CHECK-NOT: grid_axes
+ grid_axes = []
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_reduce_default_reduction
+func.func @all_reduce_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+ %0 = shard.all_reduce %arg0 on @grid0
+ grid_axes = [0]
+// CHECK-NOT: reduction
+ reduction = sum
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @all_to_all_empty_grid_axes
+func.func @all_to_all_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
+ %arg0 : tensor<8xf32>) -> tensor<8xf32> {
+// CHECK-NOT: shard.all_to_all
+ %0 = shard.all_to_all %arg0 on @grid0
+ grid_axes = []
+ split_axis = 0
+ concat_axis = 0
+ : tensor<8xf32> -> tensor<8xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @all_gather_empty_grid_axes
+func.func @all_gather_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.all_gather
+ %0 = shard.all_gather %arg0 on @grid0
+ grid_axes = []
+ gather_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @all_slice_empty_grid_axes
+func.func @all_slice_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.scatter
+ %0 = shard.all_slice %arg0 on @grid0
+ grid_axes = []
+ slice_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @broadcast_empty_grid_axes
+func.func @broadcast_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.broadcast
+ %0 = shard.broadcast %arg0 on @grid0
+ grid_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_grid_axes
+func.func @gather_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.gather
+ %0 = shard.gather %arg0 on @grid0
+ grid_axes = []
+ gather_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_grid_axes
+func.func @receive_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.recv
+ %0 = shard.recv %arg0 on @grid0
+ grid_axes = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_grid_axes
+func.func @reduce_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.reduce
+ %0 = shard.reduce %arg0 on @grid0
+ grid_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_grid_axes
+func.func @reduce_scatter_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.reduce_scatter
+ %0 = shard.reduce_scatter %arg0 on @grid0
+ grid_axes = []
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_scatter_empty_grid_axes_different_return_type
+func.func @reduce_scatter_empty_grid_axes_different_return_type(
+ %arg0 : tensor<4xf32>) -> tensor<4xf64> {
+// CHECK: shard.reduce_scatter
+ %0 = shard.reduce_scatter %arg0 on @grid0
+// CHECK-NOT: grid_axes
+ grid_axes = []
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<4xf64>
+ return %0 : tensor<4xf64>
+}
+
+// CHECK-LABEL: func @reduce_scatter_default_reduction
+func.func @reduce_scatter_default_reduction(
+ %arg0 : tensor<4xf32>) -> tensor<2xf64> {
+ %0 = shard.reduce_scatter %arg0 on @grid0
+ grid_axes = [0]
+// CHECK-NOT: reduction
+ reduction = sum
+ scatter_axis = 0
+ : tensor<4xf32> -> tensor<2xf64>
+ return %0 : tensor<2xf64>
+}
+
+// CHECK-LABEL: func @scatter_empty_grid_axes
+func.func @scatter_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.scatter
+ %0 = shard.scatter %arg0 on @grid0
+ grid_axes = []
+ scatter_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_grid_axes
+func.func @send_empty_grid_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: shard.send
+ %0 = shard.send %arg0 on @grid0
+ grid_axes = []
+ destination = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+shard.grid @grid4x4(shape = 4x4)
+// CHECK-LABEL: func @test_halo_sizes
+func.func @test_halo_sizes() -> !shard.sharding {
+ %c2_i64 = arith.constant 2 : i64
+ // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 2, 22] : !shard.sharding
+ %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, %c2_i64, %c2_i64, 22] : !shard.sharding
+ return %sharding : !shard.sharding
+}
+
+// CHECK-LABEL: func @test_shard_offs
+func.func @test_shard_offs() -> !shard.sharding {
+ %c2_i64 = arith.constant 2 : i64
+ // CHECK shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, 2, 3, 4, 0, 2, 3, 4, 22] : !shard.sharding
+ %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] sharded_dims_offsets = [0, 1, %c2_i64, 3, 4, 0, %c2_i64, 3, 4, 22] : !shard.sharding
+ return %sharding : !shard.sharding
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops
+func.func @test_duplicate_shardops() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ %sharding_annotated_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+ %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_3 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ %sharding_annotated_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_annotated:%.*]] = shard.shard [[vcst]] to [[vsharding]] : tensor<1024x1024xf32>
+ %sharding_annotated_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ // CHECK-NEXT: return [[vsharding_annotated]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+ return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
+
+// CHECK-LABEL: func @test_duplicate_shardops_diff
+func.func @test_duplicate_shardops_diff() -> (tensor<1024x1024xf32>, tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %cst_1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0]] : !shard.sharding
+ %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %cst_2 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_0:%.*]] = shard.sharding @grid4x4 split_axes = {{\[\[}}0, 1]] : !shard.sharding
+ %sharding_2 = shard.sharding @grid4x4 split_axes = [[0, 1]] : !shard.sharding
+ // CHECK-NEXT: [[vsharding_annotated:%.*]] = shard.shard [[vcst]] to [[vsharding_0]] : tensor<1024x1024xf32>
+ %sharding_annotated_2 = shard.shard %cst_2 to %sharding_2 : tensor<1024x1024xf32>
+ %cst_3 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
+ %sharding_3 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
+ %sharding_annotated_3 = shard.shard %cst_3 to %sharding_3 : tensor<1024x1024xf32>
+ // CHECK-NEXT: [[vsharding_annotated_1:%.*]] = shard.shard [[vsharding_annotated]] to [[vsharding]] : tensor<1024x1024xf32>
+ %sharding_annotated_1 = shard.shard %cst_1 to %sharding_1 : tensor<1024x1024xf32>
+ // CHECK-NEXT: return [[vsharding_annotated_1]], [[vsharding_annotated]] : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+ return %sharding_annotated_1, %sharding_annotated_2 : tensor<1024x1024xf32>, tensor<1024x1024xf32>
+}
diff --git a/mlir/test/Dialect/Shard/folding.mlir b/mlir/test/Dialect/Shard/folding.mlir
new file mode 100644
index 0000000000000..5a0f35b53a129
--- /dev/null
+++ b/mlir/test/Dialect/Shard/folding.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+shard.grid @grid1(shape = 2x3)
+
+// CHECK-LABEL: func.func @grid_shape_op_folding
+func.func @grid_shape_op_folding() -> (index, index) {
+ // CHECK: %[[AXIS_2_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = shard.grid_shape @grid0 axes = [1] : index
+ %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+ // CHECK: return %[[AXIS_2_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @grid_shape_op_folding_all_axes_static_grid
+func.func @grid_shape_op_folding_all_axes_static_grid() -> (index, index) {
+ // CHECK: %[[AXIS_0_SIZE:.*]] = arith.constant 2 : index
+ // CHECK: %[[AXIS_1_SIZE:.*]] = arith.constant 3 : index
+ %0:2 = shard.grid_shape @grid1 : index, index
+ // CHECK: return %[[AXIS_0_SIZE]], %[[AXIS_1_SIZE]]
+ return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
similarity index 63%
rename from mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
rename to mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
index dd2eee2f7def8..0d8d99752620a 100644
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/forward-backward-sharding-propagation.mlir
@@ -2,25 +2,25 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ shard.grid @grid(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> tensor<6x6xi32> {
%c1_i32 = arith.constant 1 : i32
// CHECK: tensor.empty()
%0 = tensor.empty() : tensor<6x6xi32>
- // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
- %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+ // CHECK-COUNT-3: shard.sharding @grid split_axes = {{\[\[0}}]]
+ %sharding_row = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %annotated_row = shard.shard %0 to %sharding_row : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
%2 = tensor.empty() : tensor<6x6xi32>
- // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+ // CHECK-COUNT-4: shard.sharding @grid split_axes = {{\[\[1}}]]
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
%9 = arith.addi %in, %in_2 : i32
linalg.yield %9 : i32
} -> tensor<6x6xi32>
- %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
- %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+ %sharding_col = shard.sharding @grid split_axes = [[1]] : !shard.sharding
+ %annotated_col = shard.shard %3 to %sharding_col : tensor<6x6xi32>
// CHECK: return
return %annotated_col : tensor<6x6xi32>
}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
similarity index 53%
rename from mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
rename to mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
index 6ab711b1b653c..e1894e5f7d4ac 100644
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ b/mlir/test/Dialect/Shard/forward-sharding-propagation.mlir
@@ -2,27 +2,27 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
- mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ shard.grid @grid(shape = 1) {sym_visibility = "private"}
func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
%c1_i32 = arith.constant 1 : i32
// CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
%0 = tensor.empty() : tensor<6x6xi32>
// CHECK: [[v1:%.*]] = linalg.fill ins
- // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+ // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharding_annotated_1:%.*]] = shard.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
%1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
- %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
- %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharding_annotated = shard.shard %1 to %sharding : tensor<6x6xi32>
// CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharding_annotated_3:%.*]] = shard.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
%3 = tensor.empty() : tensor<6x6xi32>
- // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+ // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharding_annotated_5:%.*]] = shard.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
// CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
- // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+ // CHECK: [[vsharding_6:%.*]] = shard.sharding @grid split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK: [[vsharding_annotated_7:%.*]] = shard.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
: tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
^bb0(%in: i32, %in_2: i32, %out: i32):
@@ -33,17 +33,17 @@ module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:com
%6 = tensor.empty() : tensor<i32>
%7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
// CHECK: [[vreduced:%.*]] = linalg.reduce ins
- // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] : !mesh.sharding
- // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+ // CHECK: [[vsharding_12:%.*]] = shard.sharding @grid split_axes = [] : !shard.sharding
+ // CHECK: [[vsharding_annotated_13:%.*]] = shard.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
%reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
(%in: i32, %init: i32) {
%9 = arith.addi %in, %init : i32
linalg.yield %9 : i32
}
- // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
- %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
- // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
- %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+ // CHECK: [[vsharding_14:%.*]] = shard.sharding @grid split_axes = {{\[\[}}]] : !shard.sharding
+ %sharding_0 = shard.sharding @grid split_axes = [[]] : !shard.sharding
+ // CHECK: [[vsharding_annotated_15:%.*]] = shard.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+ %sharding_annotated_1 = shard.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
}
}
diff --git a/mlir/test/Dialect/Shard/inlining.mlir b/mlir/test/Dialect/Shard/inlining.mlir
new file mode 100644
index 0000000000000..ce664b31abf7a
--- /dev/null
+++ b/mlir/test/Dialect/Shard/inlining.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -inline %s | FileCheck %s
+
+shard.grid @grid0(shape = 4x?x2)
+
+func.func private @grid_to_inline() -> (index, index) {
+ %0:2 = shard.grid_shape @grid0 axes = [2, 1] : index, index
+ return %0#0, %0#1 : index, index
+}
+// CHECK-LABEL: func.func @main
+func.func @main() -> (index, index) {
+ // CHECK-NEXT: %[[AXIS_SIZE:.*]]:2 = shard.grid_shape @grid0 axes = [2, 1] : index
+ %0:2 = func.call @grid_to_inline() : () -> (index, index)
+ // CHECK-NEXT: return %[[AXIS_SIZE]]#0, %[[AXIS_SIZE]]#1
+ return %0#0, %0#1 : index, index
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Shard/invalid.mlir
similarity index 57%
rename from mlir/test/Dialect/Mesh/invalid.mlir
rename to mlir/test/Dialect/Shard/invalid.mlir
index 2656332942382..6acac971164ed 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Shard/invalid.mlir
@@ -1,55 +1,55 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
-// expected-error@+1 {{rank of mesh is expected to be a positive integer}}
-mesh.mesh @mesh0(shape = [])
+// expected-error@+1 {{rank of grid is expected to be a positive integer}}
+shard.grid @grid0(shape = [])
// -----
-// expected-error@+1 {{custom op 'mesh.mesh' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
-mesh.mesh @mesh0(shape = -1)
+// expected-error@+1 {{custom op 'shard.grid' Failed parsing dimension list. Did you mean an empty list? It must be denoted by "[]".}}
+shard.grid @grid0(shape = -1)
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_axis_duplicated_different_subarray(
+func.func @grid_axis_duplicated_different_subarray(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error@+1 {{mesh axis duplicated}}
- %s = mesh.sharding @mesh0 split_axes = [[0], [0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{grid axis duplicated}}
+ %s = shard.sharding @grid0 split_axes = [[0], [0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_axis_duplicated_same_subarray(
+func.func @grid_axis_duplicated_same_subarray(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error@+1 {{mesh axis duplicated}}
- %s = mesh.sharding @mesh0 split_axes = [[0, 0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{grid axis duplicated}}
+ %s = shard.sharding @grid0 split_axes = [[0, 0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_axis_negtive_in_split_part(
+func.func @grid_axis_negtive_in_split_part(
%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // expected-error@+1 {{mesh axis is expected to be non-negative}}
- %s = mesh.sharding @mesh0 split_axes = [[-1]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{grid axis is expected to be non-negative}}
+ %s = shard.sharding @grid0 split_axes = [[-1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
// -----
func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
- // expected-error@+1 {{custom op 'mesh.sharding' invalid kind of attribute specified}}
- %s = mesh.sharding @a::@b split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ // expected-error@+1 {{custom op 'shard.sharding' invalid kind of attribute specified}}
+ %s = shard.sharding @a::@b split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
@@ -57,8 +57,8 @@ func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{halo sizes must be specified for all split axes}}
- %s = mesh.sharding @mesh0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0], [1]] halo_sizes = [1, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
@@ -66,292 +66,292 @@ func.func @sharding_attribute_invalid_halo(%arg0 : tensor<4x8xf32>) {
func.func @sharding_attribute_invalid_sizes(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{halo sizes and shard offsets are mutually exclusive}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2] sharded_dims_offsets = [0, 2, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh_dyn(shape = ?x?)
-func.func @sharding_dyn_mesh_and_sizes(%arg0 : tensor<4x8xf32>) {
- // expected-error@+1 {{sharded dims offsets are not allowed for devices meshes with dynamic shape}}
- %s = mesh.sharding @mesh_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+shard.grid @grid_dyn(shape = ?x?)
+func.func @sharding_dyn_grid_and_sizes(%arg0 : tensor<4x8xf32>) {
+ // expected-error@+1 {{sharded dims offsets are not allowed for device grids with dynamic shape}}
+ %s = shard.sharding @grid_dyn split_axes = [[0]] sharded_dims_offsets = [0, 2, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @sharding_sizes_count(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{sharded dims offsets has wrong size}}
- %s = mesh.sharding @mesh0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0], [1]] sharded_dims_offsets = [0, 2, 4, 0, 2, 4, 6] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh0(shape = 4)
+shard.grid @grid0(shape = 4)
func.func @sharding_sizes_decreasing(%arg0 : tensor<4x8xf32>) {
   // expected-error@+1 {{sharded dims offsets must be non-decreasing}}
- %s = mesh.sharding @mesh0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %s = shard.sharding @grid0 split_axes = [[0]] sharded_dims_offsets = [0, 2, 3, 2] : !shard.sharding
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_shape_mesh_axis_out_of_bounds() -> (index, index) {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0:2 = mesh.mesh_shape @mesh0 axes = [0, 2] : index, index
+func.func @grid_shape_grid_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0:2 = shard.grid_shape @grid0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @mesh_shape_duplicate_mesh_axis() -> (index, index, index) {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0:3 = mesh.mesh_shape @mesh0 axes = [0, 2, 0] : index, index, index
+func.func @grid_shape_duplicate_grid_axis() -> (index, index, index) {
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0:3 = shard.grid_shape @grid0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @mesh_shape_wrong_number_of_results() -> (index, index) {
+func.func @grid_shape_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
- %0:2 = mesh.mesh_shape @mesh0 axes = [0] : index, index
+ %0:2 = shard.grid_shape @grid0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @mesh_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @grid_shape_wrong_number_of_results_empty_grid_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
- %0:2 = mesh.mesh_shape @mesh0 : index, index
+ %0:2 = shard.grid_shape @grid0 : index, index
return %0#0, %0#1 : index, index
}
// -----
-func.func @mesh_shape_invalid_mesh_name() -> (index) {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.mesh_shape @this_mesh_symbol_does_not_exist : index
+func.func @grid_shape_invalid_grid_name() -> (index) {
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.grid_shape @this_grid_symbol_does_not_exist : index
return %0#0 : index
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @process_multi_index_mesh_axis_out_of_bounds() -> (index, index) {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 2] : index, index
+func.func @process_multi_index_grid_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0:2 = shard.process_multi_index on @grid0 axes = [0, 2] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @process_multi_index_duplicate_mesh_axis() -> (index, index, index) {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0:3 = mesh.process_multi_index on @mesh0 axes = [0, 2, 0] : index, index, index
+func.func @process_multi_index_duplicate_grid_axis() -> (index, index, index) {
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0:3 = shard.process_multi_index on @grid0 axes = [0, 2, 0] : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @process_multi_index_wrong_number_of_results() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
- %0:2 = mesh.process_multi_index on @mesh0 axes = [0] : index, index
+ %0:2 = shard.process_multi_index on @grid0 axes = [0] : index, index
return %0#0, %0#1 : index, index
}
// -----
-mesh.mesh @mesh0(shape = 1x2x3)
+shard.grid @grid0(shape = 1x2x3)
-func.func @process_multi_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+func.func @process_multi_index_wrong_number_of_results_empty_grid_axes() -> (index, index) {
  // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
- %0:2 = mesh.process_multi_index on @mesh0 : index, index
+ %0:2 = shard.process_multi_index on @grid0 : index, index
return %0#0, %0#1 : index, index
}
// -----
-func.func @process_multi_index_invalid_mesh_name() -> (index) {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.process_multi_index on @this_mesh_symbol_does_not_exist : index
+func.func @process_multi_index_invalid_grid_name() -> (index) {
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.process_multi_index on @this_grid_symbol_does_not_exist : index
return %0 : index
}
// -----
-func.func @process_linear_index_invalid_mesh_name() -> (index) {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.process_linear_index on @this_mesh_symbol_does_not_exist : index
+func.func @process_linear_index_invalid_grid_name() -> (index) {
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.process_linear_index on @this_grid_symbol_does_not_exist : index
return %0 : index
}
// -----
-func.func @all_reduce_invalid_mesh_symbol(
+func.func @all_reduce_invalid_grid_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_reduce %arg0 on @this_mesh_symbol_does_not_exist reduction = sum
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.all_reduce %arg0 on @this_grid_symbol_does_not_exist reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_reduce_invalid_mesh_axis(
+func.func @all_reduce_invalid_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [2] reduction = sum
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [2] reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_reduce_duplicate_mesh_axis(
+func.func @all_reduce_duplicate_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0, 1, 0] reduction = sum
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1, 0] reduction = sum
: tensor<4xf32> -> tensor<4xf64>
return %0 : tensor<4xf64>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @all_reduce_invalid_tensor_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<5xf64> {
- // expected-error@+1 {{'mesh.all_reduce' op requires the same shape for all operands and results}}
- %0 = mesh.all_reduce %arg0 on @mesh0 : tensor<4xf32> -> tensor<5xf64>
+ // expected-error@+1 {{'shard.all_reduce' op requires the same shape for all operands and results}}
+ %0 = shard.all_reduce %arg0 on @grid0 : tensor<4xf32> -> tensor<5xf64>
return %0 : tensor<5xf64>
}
// -----
-func.func @all_gather_invalid_mesh_symbol(
+func.func @all_gather_invalid_grid_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_gather %arg0 on @this_mesh_symbol_does_not_exist gather_axis = 0
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.all_gather %arg0 on @this_grid_symbol_does_not_exist gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_gather_invalid_mesh_axis(
+func.func @all_gather_invalid_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 0
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @all_reduce_duplicate_mesh_axis(
+func.func @all_reduce_duplicate_grid_axis(
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2, 2] gather_axis = 0
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2, 2] gather_axis = 0
: tensor<4xf32> -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
: tensor<3x4xf32> -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1x2)
+shard.grid @grid0(shape = 1x2)
func.func @all_gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1
: tensor<3x4xf32> -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
- %0 = mesh.all_gather %arg0 on @mesh0 gather_axis = 0
+ %0 = shard.all_gather %arg0 on @grid0 gather_axis = 0
: tensor<?xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1
: tensor<3xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @all_gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1
: tensor<3xf32> -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
-func.func @all_slice_duplicate_mesh_axis(
+func.func @all_slice_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0, 0]
slice_axis = 0
: tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -359,12 +359,12 @@ func.func @all_slice_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_slice_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.all_slice %arg0 on @mesh0
+ %0 = shard.all_slice %arg0 on @grid0
slice_axis = 0
: tensor<?xf32> -> tensor<2xf32>
return %0 : tensor<2xf32>
@@ -372,12 +372,12 @@ func.func @all_slice_invalid_dynamic_dimension(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_slice_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0]
slice_axis = 0
: tensor<3xf32> -> tensor<2xf32>
return %0 : tensor<2xf32>
@@ -385,12 +385,12 @@ func.func @all_slice_invalid_static_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_slice_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [0]
slice_axis = 0
: tensor<4xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -398,10 +398,10 @@ func.func @all_slice_invalid_operand_static_dimension_size(
// -----
-func.func @all_to_all_invalid_mesh_symbol(
+func.func @all_to_all_invalid_grid_symbol(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.all_to_all %arg0 on @this_mesh_symbol_does_not_exist
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.all_to_all %arg0 on @this_grid_symbol_does_not_exist
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -409,12 +409,12 @@ func.func @all_to_all_invalid_mesh_symbol(
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
-func.func @all_to_all_duplicate_mesh_axis(
+func.func @all_to_all_duplicate_grid_axis(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 0]
split_axis = 0 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -422,12 +422,12 @@ func.func @all_to_all_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = ?x1)
+shard.grid @grid0(shape = ?x1)
func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_device_group(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 6.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -435,12 +435,12 @@ func.func @all_to_all_invalid_non_dynamic_result_dimension_induced_by_dynamic_de
// -----
-mesh.mesh @mesh0(shape = 1x1)
+shard.grid @grid0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<?x6xi8>) -> tensor<3x?xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1]
split_axis = 0 concat_axis = 1
: tensor<?x6xi8> -> tensor<3x?xi8>
return %0 : tensor<3x?xi8>
@@ -448,12 +448,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_induced_by_dyna
// -----
-mesh.mesh @mesh0(shape = 1x1)
+shard.grid @grid0(shape = 1x1)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dynamic_operand_dimension(
%arg0 : tensor<3x?xi8>) -> tensor<?x3xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected dynamic, but got 3.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [1]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [1]
split_axis = 0 concat_axis = 1
: tensor<3x?xi8> -> tensor<?x3xi8>
return %0 : tensor<?x3xi8>
@@ -461,12 +461,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_induced_by_dyn
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<1x7xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 6, but got 7.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x2xi8> -> tensor<1x7xi8>
return %0 : tensor<1x7xi8>
@@ -474,12 +474,12 @@ func.func @all_to_all_invalid_non_dynamic_result_concat_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
%arg0 : tensor<3x2xi8>) -> tensor<2x6xi8> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0]
split_axis = 0 concat_axis = 1
: tensor<3x2xi8> -> tensor<2x6xi8>
return %0 : tensor<2x6xi8>
@@ -487,12 +487,12 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @broadcast_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
root = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -500,12 +500,12 @@ func.func @broadcast_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @broadcast_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
root = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -513,12 +513,12 @@ func.func @broadcast_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @broadcast_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.broadcast' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0]
root = [2]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
@@ -526,84 +526,84 @@ func.func @broadcast_different_input_and_result_type(
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_wrong_return_element_type(
%arg0 : tensor<1xf32>) -> tensor<1xi8> {
- // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ // expected-error@+1 {{'shard.gather' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0]
: (tensor<1xf32>) -> tensor<1xi8>
return %0 : tensor<1xi8>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_non_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0 root = [0]
: (tensor<3x4xf32>) -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1x2)
+shard.grid @grid0(shape = 1x2)
func.func @gather_invalid_gather_axis_dimension_size(
%arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [1] gather_axis = 1 root = [0]
: (tensor<3x4xf32>) -> tensor<3x5xf32>
return %0 : tensor<3x5xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_gather_axis_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
- %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+ %0 = shard.gather %arg0 on @grid0 gather_axis = 0 root = []
: (tensor<?xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 1 root = [0]
: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 1)
+shard.grid @grid0(shape = 1)
func.func @gather_invalid_negative_gather_axis(
%arg0 : tensor<3xf32>) -> tensor<3xf32> {
  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = -1 root = [0]
: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @gather_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<6xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
root = [3]
: (tensor<2xi8>) -> tensor<6xi8>
return %0 : tensor<6xi8>
@@ -611,12 +611,12 @@ func.func @gather_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @gather_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0] gather_axis = 0
root = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -624,12 +624,12 @@ func.func @gather_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @receive_source_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
source = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -637,12 +637,12 @@ func.func @receive_source_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @receive_source_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
source = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -650,12 +650,12 @@ func.func @receive_source_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @receive_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.recv' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0]
source = [2]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
@@ -663,12 +663,12 @@ func.func @receive_different_input_and_result_type(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @reduce_root_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
root = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -676,12 +676,12 @@ func.func @reduce_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @reduce_root_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
root = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -689,12 +689,12 @@ func.func @reduce_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @reduce_different_input_and_result_shape(
%arg0 : tensor<2xi8>) -> tensor<3xi16> {
- // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.reduce' op failed to verify that all of {input, result} have same shape}}
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0]
root = [2]
: (tensor<2xi8>) -> tensor<3xi16>
return %0 : tensor<3xi16>
@@ -702,60 +702,60 @@ func.func @reduce_different_input_and_result_shape(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
-func.func @reduce_scatter_duplicate_mesh_axis(
+func.func @reduce_scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0, 0] scatter_axis = 0
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0, 0] scatter_axis = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 scatter_axis = 0
: tensor<?xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf64> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
: tensor<3xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [0] scatter_axis = 0
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
-func.func @scatter_duplicate_mesh_axis(
+func.func @scatter_duplicate_grid_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 0]
scatter_axis = 0 root = [0, 0]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -763,12 +763,12 @@ func.func @scatter_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @scatter_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.scatter %arg0 on @mesh0
+ %0 = shard.scatter %arg0 on @grid0
scatter_axis = 0 root = []
: (tensor<?xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
@@ -776,12 +776,12 @@ func.func @scatter_invalid_dynamic_dimension(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @scatter_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [1]
: (tensor<3xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
@@ -789,12 +789,12 @@ func.func @scatter_invalid_static_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3)
+shard.grid @grid0(shape = 3)
func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [1]
: (tensor<4xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -802,12 +802,12 @@ func.func @scatter_invalid_operand_static_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @scatter_root_dimension_out_of_bounds(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [3]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
@@ -815,12 +815,12 @@ func.func @scatter_root_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @scatter_root_wrong_number_dimensions(
%arg0 : tensor<3xi8>) -> tensor<1xi8> {
  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0]
scatter_axis = 0 root = [2, 2]
: (tensor<3xi8>) -> tensor<1xi8>
return %0 : tensor<1xi8>
@@ -828,12 +828,12 @@ func.func @scatter_root_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @send_destination_dimension_out_of_bounds(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0]
destination = [3]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -841,12 +841,12 @@ func.func @send_destination_dimension_out_of_bounds(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @send_destination_wrong_number_dimensions(
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0]
destination = [2, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -854,12 +854,12 @@ func.func @send_destination_wrong_number_dimensions(
// -----
-mesh.mesh @mesh0(shape = 3x?)
+shard.grid @grid0(shape = 3x?)
func.func @send_different_input_and_result_type(
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.send' op failed to verify that all of {input, result} have same element type}}
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0]
destination = [2]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
@@ -867,10 +867,10 @@ func.func @send_different_input_and_result_type(
// -----
-func.func @shift_invalid_mesh_symbol(
+func.func @shift_invalid_grid_symbol(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
- %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
+ // expected-error@+1 {{Undefined required grid symbol "this_grid_symbol_does_not_exist".}}
+ %0 = shard.shift %arg0 on @this_grid_symbol_does_not_exist
shift_axis = 0 offset = -2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
@@ -878,12 +878,12 @@ func.func @shift_invalid_mesh_symbol(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @shift_invalid_mesh_axis(
+func.func @shift_invalid_grid_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
+ // expected-error@+1 {{0-based grid axis index 2 is out of bounds. The referenced grid "grid0" is of rank 2.}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [2]
shift_axis = 2 offset = -2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
@@ -891,12 +891,12 @@ func.func @shift_invalid_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
-func.func @shift_duplicate_mesh_axis(
+func.func @shift_duplicate_grid_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
+ // expected-error@+1 {{Grid axes contains duplicate elements.}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 1, 0]
shift_axis = 0 offset = -2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
@@ -904,12 +904,12 @@ func.func @shift_duplicate_mesh_axis(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @shift_invalid_tensor_dimension_size(
%arg0 : tensor<4xi8>) -> tensor<5xi8> {
- // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{'shard.shift' op requires the same shape for all operands and results}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0]
shift_axis = 0 offset = 2
: tensor<4xi8> -> tensor<5xi8>
return %0 : tensor<5xi8>
@@ -917,12 +917,12 @@ func.func @shift_invalid_tensor_dimension_size(
// -----
-mesh.mesh @mesh0(shape = 2x4)
+shard.grid @grid0(shape = 2x4)
func.func @shift_invalid_shift_axis(
%arg0 : tensor<4xi8>) -> tensor<4xi8> {
- // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+ // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping grid axes.}}
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0]
shift_axis = 1 offset = 2
: tensor<4xi8> -> tensor<4xi8>
return %0 : tensor<4xi8>
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Shard/ops.mlir
similarity index 55%
rename from mlir/test/Dialect/Mesh/ops.mlir
rename to mlir/test/Dialect/Shard/ops.mlir
index c354de514fba8..5265dadd2a845 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Shard/ops.mlir
@@ -1,176 +1,176 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK: shard.grid @grid0
+shard.grid @grid0(shape = 2x2x4)
-// CHECK: mesh.mesh @mesh1(shape = 4x?)
-mesh.mesh @mesh1(shape = 4x?)
+// CHECK: shard.grid @grid1(shape = 4x?)
+shard.grid @grid1(shape = 4x?)
-// CHECK: mesh.mesh @mesh2(shape = ?x4)
-mesh.mesh @mesh2(shape = ?x4)
+// CHECK: shard.grid @grid2(shape = ?x4)
+shard.grid @grid2(shape = ?x4)
-// CHECK: mesh.mesh @mesh3(shape = ?x?)
-mesh.mesh @mesh3(shape = ?x?)
+// CHECK: shard.grid @grid3(shape = ?x?)
+shard.grid @grid3(shape = ?x?)
-mesh.mesh @mesh4(shape = 3)
+shard.grid @grid4(shape = 3)
-// CHECK: mesh.mesh @mesh5(shape = ?)
-mesh.mesh @mesh5(shape = ?)
+// CHECK: shard.grid @grid5(shape = ?)
+shard.grid @grid5(shape = ?)
-// CHECK-LABEL: func @mesh_shard_op_fully_replicated
+// CHECK-LABEL: func @grid_shard_op_fully_replicated
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
- %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+func.func @grid_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding
+ %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
+ // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_1st_dim
+// CHECK-LABEL: func @grid_shard_op_1st_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
- %s = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
+func.func @grid_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding
+ %s = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_2nd_dim
+// CHECK-LABEL: func @grid_shard_op_2nd_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh1 split_axes = {{\[\[}}], [0]] : !mesh.sharding
- %s = mesh.sharding @mesh1 split_axes = [[], [0]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8xf32>
+func.func @grid_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid1 split_axes = {{\[\[}}], [0]] : !shard.sharding
+ %s = shard.sharding @grid1 split_axes = [[], [0]] : !shard.sharding
+ // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
-func.func @mesh_shard_op_1st_and_3rd_dim(
+// CHECK-LABEL: func @grid_shard_op_1st_and_3rd_dim
+func.func @grid_shard_op_1st_and_3rd_dim(
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
%arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh3 split_axes = {{\[\[}}0], [], [1]] : !mesh.sharding
- %s = mesh.sharding @mesh3 split_axes = [[0], [], [1]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
- %0 = mesh.shard %arg0 to %s : tensor<4x8x16xf32>
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid3 split_axes = {{\[\[}}0], [], [1]] : !shard.sharding
+ %s = shard.sharding @grid3 split_axes = [[0], [], [1]] : !shard.sharding
+ // CHECK-NEXT: shard.shard %[[ARG]] to %[[S]] : tensor<4x8x16xf32>
+ %0 = shard.shard %arg0 to %s : tensor<4x8x16xf32>
return %0 : tensor<4x8x16xf32>
}
-// CHECK-LABEL: func @mesh_shard_op_two_users
+// CHECK-LABEL: func @grid_shard_op_two_users
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
+func.func @grid_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
(tensor<4x8xf32>, tensor<4x8xf32>) {
- // CHECK-NEXT: %[[V0:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}0]] : !mesh.sharding
- %s0 = mesh.sharding @mesh0 split_axes = [[0]] : !mesh.sharding
- %0 = mesh.shard %arg0 to %s0 : tensor<4x8xf32>
- // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}1]] : !mesh.sharding
- %s1 = mesh.sharding @mesh0 split_axes = [[1]] : !mesh.sharding
- %1 = mesh.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
- // CHECK-DAG: mesh.sharding @mesh0 split_axes = {{\[\[}}2]] : !mesh.sharding
- %s2 = mesh.sharding @mesh0 split_axes = [[2]] : !mesh.sharding
- %2 = mesh.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
+ // CHECK-NEXT: %[[V0:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}0]] : !shard.sharding
+ %s0 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<4x8xf32>
+ // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}1]] : !shard.sharding
+ %s1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<4x8xf32>
+ // CHECK-DAG: shard.sharding @grid0 split_axes = {{\[\[}}2]] : !shard.sharding
+ %s2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
+ %2 = shard.shard %0 to %s2 annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
-// CHECK-LABEL: func @mesh_shard_halo_sizes
-func.func @mesh_shard_halo_sizes() -> () {
+// CHECK-LABEL: func @grid_shard_halo_sizes
+func.func @grid_shard_halo_sizes() -> () {
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
%c3 = arith.constant 3 : i64
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !mesh.sharding
- %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 4] : !mesh.sharding
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [4, %c3] : !mesh.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [1, 4] : !shard.sharding
+ %sharding1 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [1, 4] : !shard.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] halo_sizes = [4, %[[C3]]] : !shard.sharding
+ %sharding2 = shard.sharding @grid4 split_axes = [[0]] halo_sizes = [4, %c3] : !shard.sharding
return
}
-// CHECK-LABEL: func @mesh_shard_dims_sizes
-func.func @mesh_shard_dims_sizes() -> () {
+// CHECK-LABEL: func @grid_shard_dims_sizes
+func.func @grid_shard_dims_sizes() -> () {
// CHECK: %[[C3:.*]] = arith.constant 3 : i64
%c3 = arith.constant 3 : i64
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
- %sharding1 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !mesh.sharding
- // CHECK: mesh.sharding @mesh4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !mesh.sharding
- %sharding2 = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !mesh.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding
+ %sharding1 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 6] : !shard.sharding
+ // CHECK: shard.sharding @grid4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 2, %[[C3]], 5] : !shard.sharding
+ %sharding2 = shard.sharding @grid4 split_axes = [[0]] sharded_dims_offsets = [0, 2, %c3, 5] : !shard.sharding
return
}
-// CHECK-LABEL: func @mesh_shard_shape
-func.func @mesh_shard_shape() {
+// CHECK-LABEL: func @grid_shard_shape
+func.func @grid_shard_shape() {
// CHECK: %[[C3:.*]] = arith.constant 3 : index
%c3 = arith.constant 3 : index
- // CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
- %s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+ // CHECK-NEXT: %[[S:.*]] = shard.sharding @grid0 split_axes = {{\[\[}}]] : !shard.sharding
+ %s = shard.sharding @grid0 split_axes = [[]] : !shard.sharding
+ // CHECK-NEXT: shard.shard_shape dims = [8, %[[C3]]
// CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
// CHECK-SAME: ] : index, index
- %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
- // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
- %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
+ %shp:2 = shard.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+ // CHECK-NEXT: shard.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+ %shp1:2 = shard.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
return
}
-// CHECK-LABEL: func @mesh_get_sharding
+// CHECK-LABEL: func @grid_get_sharding
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
-func.func @mesh_get_sharding(%arg0 : tensor<4x8xf32>) -> !mesh.sharding {
- // CHECK-NEXT: mesh.get_sharding %[[ARG]] : tensor<4x8xf32> -> !mesh.sharding
- %0 = mesh.get_sharding %arg0 : tensor<4x8xf32> -> !mesh.sharding
- return %0 : !mesh.sharding
+func.func @grid_get_sharding(%arg0 : tensor<4x8xf32>) -> !shard.sharding {
+ // CHECK-NEXT: shard.get_sharding %[[ARG]] : tensor<4x8xf32> -> !shard.sharding
+ %0 = shard.get_sharding %arg0 : tensor<4x8xf32> -> !shard.sharding
+ return %0 : !shard.sharding
}
-// CHECK-LABEL: func @mesh_shape
-func.func @mesh_shape() -> (index, index) {
- // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
- %0:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
+// CHECK-LABEL: func @grid_shape
+func.func @grid_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index
+ %0:2 = shard.grid_shape @grid0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}
-// CHECK-LABEL: func @mesh_shape_default_axes
-func.func @mesh_shape_default_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
- %0:3 = mesh.mesh_shape @mesh0 : index, index, index
+// CHECK-LABEL: func @grid_shape_default_axes
+func.func @grid_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index
+ %0:3 = shard.grid_shape @grid0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
-// CHECK-LABEL: func @mesh_shape_empty_axes
-func.func @mesh_shape_empty_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
- %0:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
+// CHECK-LABEL: func @grid_shape_empty_axes
+func.func @grid_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = shard.grid_shape @grid0 : index, index, index
+ %0:3 = shard.grid_shape @grid0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_multi_index
func.func @process_multi_index() -> (index, index) {
- // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
- %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
+ // CHECK: %[[RES:.*]]:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index
+ %0:2 = shard.process_multi_index on @grid0 axes = [0, 1] : index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
return %0#0, %0#1 : index, index
}
// CHECK-LABEL: func @process_multi_index_default_axes
func.func @process_multi_index_default_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_multi_index_empty_axes
func.func @process_multi_index_empty_axes() -> (index, index, index) {
- // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ // CHECK: %[[RES:.*]]:3 = shard.process_multi_index on @grid0 : index, index, index
+ %0:3 = shard.process_multi_index on @grid0 axes = [] : index, index, index
// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}
// CHECK-LABEL: func @process_linear_index
func.func @process_linear_index() -> index {
- // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
- %0 = mesh.process_linear_index on @mesh0 : index
+ // CHECK: %[[RES:.*]] = shard.process_linear_index on @grid0 : index
+ %0 = shard.process_linear_index on @grid0 : index
// CHECK: return %[[RES]] : index
return %0 : index
}
@@ -179,9 +179,9 @@ func.func @process_linear_index() -> index {
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> {
- // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = max
+ // CHECK-NEXT: shard.all_reduce %[[ARG]] on @grid0 grid_axes = [1, 0] reduction = max
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = max
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [1, 0] reduction = max
: tensor<3x4xf32> -> tensor<3x4xf64>
return %0 : tensor<3x4xf64>
}
@@ -190,9 +190,9 @@ func.func @all_reduce(
func.func @all_gather(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> {
- // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+ // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1
: tensor<3x4xf32> -> tensor<3x16xf32>
return %0 : tensor<3x16xf32>
}
@@ -201,20 +201,20 @@ func.func @all_gather(
func.func @all_gather_dynamic_dims_in_tensor(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
+ // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid0 grid_axes = [2] gather_axis = 1
// CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
- %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1
: tensor<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
-// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
-func.func @all_gather_dynamic_dims_in_mesh(
+// CHECK-LABEL: func @all_gather_dynamic_dims_in_grid
+func.func @all_gather_dynamic_dims_in_grid(
// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
%arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> {
- // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
+ // CHECK-NEXT: shard.all_gather %[[ARG]] on @grid3 grid_axes = [1] gather_axis = 1
// CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
- %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1
+ %0 = shard.all_gather %arg0 on @grid3 grid_axes = [1] gather_axis = 1
: tensor<5x6xf32> -> tensor<5x?xf32>
return %0 : tensor<5x?xf32>
}
@@ -223,10 +223,10 @@ func.func @all_gather_dynamic_dims_in_mesh(
func.func @all_slice_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
- // CHECK-NEXT: mesh.all_slice %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
+ // CHECK-NEXT: shard.all_slice %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [2] slice_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
- %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
+ %0 = shard.all_slice %arg0 on @grid0 grid_axes = [2] slice_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
}
@@ -235,10 +235,10 @@ func.func @all_slice_static_dimensions(
func.func @all_slice_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // CHECK-NEXT: mesh.all_slice %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
+ // CHECK-NEXT: shard.all_slice %[[ARG]]
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1] slice_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
- %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
+ %0 = shard.all_slice %arg0 on @grid3 grid_axes = [0, 1] slice_axis = 0
: tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -247,10 +247,10 @@ func.func @all_slice_dynamic_dimensions(
func.func @all_to_all(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
- %0 = mesh.all_to_all %arg0 on @mesh4
+ %0 = shard.all_to_all %arg0 on @grid4
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -260,10 +260,10 @@ func.func @all_to_all(
func.func @all_to_all_dynamic_dims_in_result(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: on @grid4 split_axis = 1 concat_axis = 0
// CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
- %0 = mesh.all_to_all %arg0 on @mesh4
+ %0 = shard.all_to_all %arg0 on @grid4
split_axis = 1 concat_axis = 0
: tensor<3x6xi8> -> tensor<3x?xi8>
return %0 : tensor<3x?xi8>
@@ -273,10 +273,10 @@ func.func @all_to_all_dynamic_dims_in_result(
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
%arg0 : tensor<3xi8>) -> tensor<3xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: @grid4 split_axis = 0 concat_axis = 0
// CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
- %0 = mesh.all_to_all %arg0 on @mesh4
+ %0 = shard.all_to_all %arg0 on @grid4
split_axis = 0 concat_axis = 0
: tensor<3xi8> -> tensor<3xi8>
return %0 : tensor<3xi8>
@@ -286,10 +286,10 @@ func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
func.func @all_to_all_non_divisible_split_axis_size(
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
%arg0 : tensor<2x3xi8>) -> tensor<?x12xi8> {
- // CHECK-NEXT: mesh.all_to_all %[[ARG]]
- // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
+ // CHECK-NEXT: shard.all_to_all %[[ARG]]
+ // CHECK-SAME: @grid0 grid_axes = [0, 1] split_axis = 0 concat_axis = 1
// CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
- %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1]
+ %0 = shard.all_to_all %arg0 on @grid0 grid_axes = [0, 1]
split_axis = 0 concat_axis = 1
: tensor<2x3xi8> -> tensor<?x12xi8>
return %0 : tensor<?x12xi8>
@@ -299,11 +299,11 @@ func.func @all_to_all_non_divisible_split_axis_size(
func.func @broadcast_static_root(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
- // CHECK-NEXT: mesh.broadcast %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.broadcast %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2]
root = [0, 1]
: (tensor<3x6xi8>) -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -316,11 +316,11 @@ func.func @broadcast_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<3x6xi8> {
- // CHECK-NEXT: mesh.broadcast %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.broadcast %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
- %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.broadcast %arg0 on @grid0 grid_axes = [0, 2]
root = [1, %arg1]
: (tensor<3x6xi8>, index) -> tensor<3x6xi8>
return %0 : tensor<3x6xi8>
@@ -330,12 +330,12 @@ func.func @broadcast_dynamic_root(
func.func @gather_static_root(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
%arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
- // CHECK-NEXT: mesh.gather %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.gather %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: gather_axis = 0
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2]
gather_axis = 0
root = [0, 1]
: (tensor<3x6xi8>) -> tensor<24x6xi8>
@@ -349,12 +349,12 @@ func.func @gather_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<24x6xi8> {
- // CHECK-NEXT: mesh.gather %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.gather %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: gather_axis = 0
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
- %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.gather %arg0 on @grid0 grid_axes = [0, 2]
gather_axis = 0
root = [1, %arg1]
: (tensor<3x6xi8>, index) -> tensor<24x6xi8>
@@ -365,11 +365,11 @@ func.func @gather_dynamic_root(
func.func @receive_static_source(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.recv %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.recv %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: source = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
source = [0, 1]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -382,11 +382,11 @@ func.func @receive_dynamic_source(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.recv %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.recv %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: source = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
source = [1, %arg1]
: (tensor<2xi8>, index) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -396,9 +396,9 @@ func.func @receive_dynamic_source(
func.func @receive_no_source(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.recv %[[ARG]]
+ // CHECK-NEXT: shard.recv %[[ARG]]
// CHECK-NOT: source
- %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.recv %arg0 on @grid0 grid_axes = [0, 2]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -407,11 +407,11 @@ func.func @receive_no_source(
func.func @reduce_static_root(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.reduce %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.reduce %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
root = [0, 1]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -424,11 +424,11 @@ func.func @reduce_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.reduce %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.reduce %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
root = [1, %arg1]
: (tensor<2xi8>, index) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -438,11 +438,11 @@ func.func @reduce_dynamic_root(
func.func @reduce_different_return_element_type(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi16> {
- // CHECK-NEXT: mesh.reduce %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.reduce %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: root = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
- %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.reduce %arg0 on @grid0 grid_axes = [0, 2]
root = [0, 1]
: (tensor<2xi8>) -> tensor<2xi16>
return %0 : tensor<2xi16>
@@ -452,10 +452,10 @@ func.func @reduce_different_return_element_type(
func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> {
- // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = max scatter_axis = 1
+ // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [2] reduction = max scatter_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
- %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2]
+ %0 = shard.reduce_scatter %arg0 on @grid0 grid_axes = [2]
reduction = max scatter_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf64>
return %0 : tensor<3x1xf64>
@@ -465,10 +465,10 @@ func.func @reduce_scatter_static_dimensions(
func.func @reduce_scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf64> {
- // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ // CHECK-NEXT: shard.reduce_scatter %[[ARG]]
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1] scatter_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
- %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ %0 = shard.reduce_scatter %arg0 on @grid3 grid_axes = [0, 1] scatter_axis = 0
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
@@ -477,11 +477,11 @@ func.func @reduce_scatter_dynamic_dimensions(
func.func @scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
- // CHECK-NEXT: mesh.scatter %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2]
+ // CHECK-NEXT: shard.scatter %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [2]
// CHECK-SAME: scatter_axis = 1 root = [1]
// CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [2]
scatter_axis = 1 root = [1]
: (tensor<3x4xf32>) -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
@@ -491,11 +491,11 @@ func.func @scatter_static_dimensions(
func.func @scatter_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // CHECK-NEXT: mesh.scatter %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
+ // CHECK-NEXT: shard.scatter %[[ARG]]
+ // CHECK-SAME: on @grid3 grid_axes = [0, 1]
// CHECK-SAME: scatter_axis = 0 root = [1, 2]
// CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
- %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
+ %0 = shard.scatter %arg0 on @grid3 grid_axes = [0, 1]
scatter_axis = 0 root = [1, 2]
: (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
@@ -508,12 +508,12 @@ func.func @scatter_dynamic_root(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<1xi8> {
- // CHECK-NEXT: mesh.scatter %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.scatter %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: scatter_axis = 0
// CHECK-SAME: root = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
- %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.scatter %arg0 on @grid0 grid_axes = [0, 2]
scatter_axis = 0
root = [1, %arg1]
: (tensor<8xi8>, index) -> tensor<1xi8>
@@ -524,11 +524,11 @@ func.func @scatter_dynamic_root(
func.func @send_static_destination(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.send %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.send %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: destination = [0, 1]
// CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2]
destination = [0, 1]
: (tensor<2xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -541,11 +541,11 @@ func.func @send_dynamic_destination(
// CHECK-SAME: %[[ARG1:.*]]: index
%arg1 : index
) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.send %[[ARG0]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.send %[[ARG0]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: destination = [1, %[[ARG1]]]
// CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
- %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.send %arg0 on @grid0 grid_axes = [0, 2]
destination = [1, %arg1]
: (tensor<2xi8>, index) -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -555,11 +555,11 @@ func.func @send_dynamic_destination(
func.func @shift(
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
%arg0 : tensor<2xi8>) -> tensor<2xi8> {
- // CHECK-NEXT: mesh.shift %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-NEXT: shard.shift %[[ARG]]
+ // CHECK-SAME: on @grid0 grid_axes = [0, 2]
// CHECK-SAME: shift_axis = 2 offset = -2 rotate
// CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
- %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
+ %0 = shard.shift %arg0 on @grid0 grid_axes = [0, 2]
shift_axis = 2 offset = -2 rotate
: tensor<2xi8> -> tensor<2xi8>
return %0 : tensor<2xi8>
@@ -570,16 +570,16 @@ func.func @update_halo(
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
- // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-NEXT: %[[UH1:.*]] = shard.update_halo %[[ARG]] on @grid0
// CHECK-SAME: split_axes = {{\[\[}}0]]
// CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
%c2 = arith.constant 2 : i64
- %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+ %uh1 = shard.update_halo %arg0 on @grid0 split_axes = [[0]]
halo_sizes = [2, %c2] : memref<12x12xi8>
- // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
+ // CHECK-NEXT: %[[UH2:.*]] = shard.update_halo %[[UH1]] on @grid0
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
// CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
- %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
+ %uh2 = shard.update_halo %uh1 on @grid0 split_axes = [[0], [1]]
halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
return
}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
new file mode 100644
index 0000000000000..467dfa9ef0fab
--- /dev/null
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -0,0 +1,317 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid_1d(shape = 2)
+
+// CHECK-LABEL: func @return_sharding
+func.func @return_sharding(
+ // CHECK-SAME: [[ARG:%.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> (tensor<1xf32>, !shard.sharding) {
+) -> (tensor<2xf32>, !shard.sharding) {
+ %ssharding_annotated = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %sharding_annotated = shard.shard %arg0 to %ssharding_annotated : tensor<2xf32>
+ // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding
+ %r = shard.get_sharding %sharding_annotated : tensor<2xf32> -> !shard.sharding
+ // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding
+ return %sharding_annotated, %r : tensor<2xf32>, !shard.sharding
+}
+
+// CHECK-LABEL: func @full_replication
+func.func @full_replication(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[ARG]] : tensor<2xi8>
+ return %1 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
+ %ssharding_annotated = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %sharding_annotated = shard.shard %arg0 to %ssharding_annotated : tensor<2xf32>
+ %ssharding_annotated_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %sharding_annotated_0 = shard.shard %sharding_annotated to %ssharding_annotated_0 annotate_for_users : tensor<2xf32>
+ %ssharding_annotated_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %sharding_annotated_0 to %ssharding_annotated_1 : tensor<2xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
+ return %sharding_annotated_1 : tensor<2xf32>
+}
+
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
+ %arg0: tensor<2x2xi8>
+// CHECK-SAME: -> tensor<2x1xi8> {
+) -> tensor<2x2xi8> {
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[ARG]] on @grid_1d
+ // CHECK-SAME: grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<1x2xi8> -> tensor<2x1xi8>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2x2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2x2xi8>
+ // CHECK: return %[[ALL_TO_ALL]] : tensor<2x1xi8>
+ return %1 : tensor<2x2xi8>
+}
+
+// CHECK-LABEL: func @non_tensor_value
+func.func @non_tensor_value(
+ // CHECK-SAME: %[[ARG:.*]]: i8
+ %arg0: i8
+// CHECK-SAME: -> i8 {
+) -> i8 {
+ // CHECK: %[[RES:.*]] = arith.addi %[[ARG]], %[[ARG]] : i8
+ %0 = arith.addi %arg0, %arg0 : i8
+ // CHECK: return %[[RES]] : i8
+ return %0 : i8
+}
+
+// CHECK-LABEL: func @unary_elementwise
+func.func @unary_elementwise(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[RES:.*]] = tosa.abs %[[ARG]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<2xi8>
+ %s4 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %4 : tensor<2xi8>
+}
+
+// full replication -> shard axis -> abs -> shard axis -> full replication
+// CHECK-LABEL: func @unary_elementwise_with_resharding
+func.func @unary_elementwise_with_resharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<2xi8> {
+) -> tensor<2xi8> {
+ // CHECK: %[[SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RES:.*]] = shard.all_gather %[[ABS]] on @grid_1d
+ // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<2xi8>
+ %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<2xi8>
+ return %4 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @binary_elementwise
+func.func @binary_elementwise(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1xi8>,
+ %arg0: tensor<2xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: tensor<1xi8>
+ %arg1: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %sarg0_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2xi8>
+ %sop_arg0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %op_arg0 = shard.shard %arg0_sharded to %sop_arg0 annotate_for_users : tensor<2xi8>
+ %sarg1_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %arg1_sharded = shard.shard %arg1 to %sarg1_sharded : tensor<2xi8>
+ %sop_arg1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %op_arg1 = shard.shard %arg1_sharded to %sop_arg1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[RES:.*]] = tosa.add %[[ARG0]], %[[ARG1]] : (tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
+ %op_res = tosa.add %op_arg0, %op_arg1 : (tensor<2xi8>, tensor<2xi8>) -> tensor<2xi8>
+ %sop_res_sharded = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %op_res_sharded = shard.shard %op_res to %sop_res_sharded : tensor<2xi8>
+ %sres = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %res = shard.shard %op_res_sharded to %sres annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %res : tensor<2xi8>
+}
+
+// reshard
+// abs
+// reshard
+// abs
+// reshard
+// CHECK-LABEL: func @multiple_chained_ops
+func.func @multiple_chained_ops(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ // CHECK: %[[RESHARD1:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xi8>
+ %s1 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
+ %2 = tosa.abs %1 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RESHARD2:.*]] = shard.all_gather %[[ABS1]] on @grid_1d
+ // CHECK-SAME: grid_axes = [0] gather_axis = 0 : tensor<1xi8> -> tensor<2xi8>
+ %s3 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<2xi8>
+ %s4 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 annotate_for_users : tensor<2xi8>
+ // CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
+ %5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
+ // CHECK: %[[RESHARD3:.*]] = shard.all_slice %[[ABS2]] on @grid_1d grid_axes = [0] slice_axis = 0 :
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
+ %s6 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %6 = shard.shard %5 to %s6 : tensor<2xi8>
+ %s7 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %7 = shard.shard %6 to %s7 annotate_for_users : tensor<2xi8>
+ // CHECK: return %[[RESHARD3]] : tensor<1xi8>
+ return %7 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+ %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %2 = shard.shard %1 to %s2 : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<4x16xf32>
+ return %2 : tensor<8x16xf32>
+}
+
+shard.grid @grid_1d_4(shape = 4)
+
+// CHECK-LABEL: func @ew_chain_with_halo
+func.func @ew_chain_with_halo(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<5x16xf32>
+ %arg0: tensor<8x16xf32>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xf32>
+ %arg1: tensor<1xf32>,
+ // CHECK-SAME: %[[IN3:[A-Za-z0-9_]+]]: tensor<1xf32>
+ %arg2: tensor<1xf32>)
+ // CHECK-SAME: -> tensor<5x16xf32>
+ -> tensor<8x16xf32> {
+ %ssharding_annotated = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated = shard.shard %arg0 to %ssharding_annotated annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[TMP1:.*]] = tosa.tanh %[[IN1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %0 = tosa.tanh %sharding_annotated : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharding_annotated_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated_0 = shard.shard %0 to %ssharding_annotated_0 : tensor<8x16xf32>
+ %ssharding_annotated_1 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %sharding_annotated_0 to %ssharding_annotated_1 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP2:.*]] = tosa.abs %[[TMP1]] : (tensor<5x16xf32>) -> tensor<5x16xf32>
+ %1 = tosa.abs %sharding_annotated_1 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %ssharding_annotated_2 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated_2 = shard.shard %1 to %ssharding_annotated_2 : tensor<8x16xf32>
+ %ssharding_annotated_4 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated_4 = shard.shard %sharding_annotated_2 to %ssharding_annotated_4 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[TMP3:.*]] = tosa.negate %[[TMP2]], %[[IN2]], %[[IN3]] : (tensor<5x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x16xf32>
+ %sharding_1 = shard.sharding @grid_1d_4 split_axes = [[]] : !shard.sharding
+ %zero_point_1 = shard.shard %arg1 to %sharding_1 annotate_for_users : tensor<1xf32>
+ %zero_point_2 = shard.shard %arg2 to %sharding_1 annotate_for_users : tensor<1xf32>
+ %2 = tosa.negate %sharding_annotated_4, %zero_point_1, %zero_point_2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
+ %ssharding_annotated_5 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated_5 = shard.shard %2 to %ssharding_annotated_5 : tensor<8x16xf32>
+ %ssharding_annotated_6 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 1] : !shard.sharding
+ %sharding_annotated_6 = shard.shard %sharding_annotated_5 to %ssharding_annotated_6 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[TMP3]] : tensor<5x16xf32>
+ return %sharding_annotated_6 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func @test_shard_update_halo
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x1200xi64>
+func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] : !shard.sharding
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
+ // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
+ %sharding_annotated = shard.shard %arg0 to %sharding : tensor<1200x1200xi64>
+ %sharding_0 = shard.sharding @grid_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+ %sharding_annotated_3 = shard.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+ // CHECK: return %[[UH]] : tensor<304x1200xi64>
+ return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
+
+shard.grid @grid4x4(shape = 4x4)
+// CHECK-LABEL: func @test_shard_update_halo2d
+// CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<300x300xi64>
+func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1200xi64> {
+ %sharding = shard.sharding @grid4x4 split_axes = [[0], [1]] : !shard.sharding
+ // CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
+ // CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
+ // CHECK: %[[UH:.*]] = shard.update_halo %[[inserted_slice]] on @grid4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
+ %sharding_annotated = shard.shard %arg0 to %sharding : tensor<1200x1200xi64>
+ %sharding_0 = shard.sharding @grid4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !shard.sharding
+ %sharding_annotated_1 = shard.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
+ %sharding_annotated_3 = shard.shard %sharding_annotated_1 to %sharding_0 annotate_for_users : tensor<1200x1200xi64>
+ // CHECK: return %[[UH]] : tensor<303x307xi64>
+ return %sharding_annotated_3 : tensor<1200x1200xi64>
+}
+
+shard.grid @grid(shape = 2)
+// CHECK-LABEL: func.func @test_reduce_0d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_0d(%arg0: tensor<6x6xi32>) -> (tensor<i32>) {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+ %4 = tensor.empty() : tensor<i32>
+ %sharding_out = shard.sharding @grid split_axes = [[]] : !shard.sharding
+ %sharded_out = shard.shard %4 to %sharding_out : tensor<i32>
+ %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+ // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
+ %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %6 = arith.addi %in, %init : i32
+ linalg.yield %6 : i32
+ }
+ // CHECK: %[[all_reduce:.*]] = shard.all_reduce %[[reduced]] on @grid grid_axes = [0] : tensor<i32> -> tensor<i32>
+ %sharded_red = shard.shard %reduced to %sharding_out : tensor<i32>
+ %sharded_ret = shard.shard %sharded_red to %sharding_out annotate_for_users : tensor<i32>
+ // CHECK: return %[[all_reduce]] : tensor<i32>
+ return %sharded_ret : tensor<i32>
+}
+
+// CHECK-LABEL: func.func @test_reduce_1d(
+// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<3x6xi32>
+func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
+ %4 = tensor.empty() : tensor<6xi32>
+ %sharded_out = shard.shard %4 to %sharding : tensor<6xi32>
+ %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
+ // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
+ %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
+ (%in: i32, %init: i32) {
+ %6 = arith.addi %in, %init : i32
+ linalg.yield %6 : i32
+ }
+ // CHECK-NOT: shard.all_reduce
+ %sharded_red = shard.shard %reduced to %sharding : tensor<6xi32>
+ %sharded_ret = shard.shard %sharded_red to %sharding annotate_for_users : tensor<6xi32>
+ // CHECK: return %[[reduced]] : tensor<3xi32>
+ return %sharded_ret : tensor<6xi32>
+}
diff --git a/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
new file mode 100644
index 0000000000000..33c7a8f96464d
--- /dev/null
+++ b/mlir/test/Dialect/Shard/process-multi-index-op-lowering.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -test-grid-process-multi-index-op-lowering %s | FileCheck %s
+
+shard.grid @grid2d(shape = ?x?)
+
+// CHECK-LABEL: func.func @multi_index_2d_grid
+func.func @multi_index_2d_grid() -> (index, index) {
+ // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+ // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index
+ %0:2 = shard.process_multi_index on @grid2d : index, index
+ // CHECK: return %[[MULTI_IDX]]#0, %[[MULTI_IDX]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func.func @multi_index_2d_grid_single_inner_axis
+func.func @multi_index_2d_grid_single_inner_axis() -> index {
+ // CHECK: %[[LINEAR_IDX:.*]] = shard.process_linear_index on @grid2d : index
+ // CHECK: %[[SHARD_SHAPE:.*]]:2 = shard.grid_shape @grid2d : index, index
+ // CHECK: %[[MULTI_IDX:.*]]:2 = affine.delinearize_index %[[LINEAR_IDX]] into (%[[SHARD_SHAPE]]#0, %[[SHARD_SHAPE]]#1) : index, index
+ %0 = shard.process_multi_index on @grid2d axes = [0] : index
+ // CHECK: return %[[MULTI_IDX]]#0 : index
+ return %0 : index
+}
diff --git a/mlir/test/Dialect/Shard/resharding-partition.mlir b/mlir/test/Dialect/Shard/resharding-partition.mlir
new file mode 100644
index 0000000000000..ff9e8408aa7fd
--- /dev/null
+++ b/mlir/test/Dialect/Shard/resharding-partition.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt -test-grid-resharding-partition %s | FileCheck %s
+
+shard.grid @grid_1d(shape = 2)
+shard.grid @grid_1d_dynamic(shape = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @identical_source_and_target_sharding
+func.func @identical_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<2xf32>
+ %1 = shard.shard %0 to %s0 annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+ %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+ // CHECK: %[[ALL_SLICE:.*]] = shard.all_slice %[[ARG]] on @grid_1d grid_axes = [0] slice_axis = 1
+ // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<3x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<3x14xf32>
+ // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+ return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+ %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+ // CHECK: %[[RESULT:.*]] = shard.all_slice %[[ARG]] on @grid_1d_dynamic grid_axes = [0] slice_axis = 0
+ // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
+ %s0 = shard.sharding @grid_1d_dynamic split_axes = [[], [], []] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<?x3x?xf32>
+ %s1 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x3x?xf32>
+ // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
+ return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_grid
+func.func @move_split_axis_dynamic_grid(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = shard.all_to_all %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d_dynamic split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[TARGET_SHARD:.*]] = shard.all_to_all %[[ARG]] on @grid_1d grid_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[RES]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_last_axis
+func.func @unshard_static_last_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<10x7xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d grid_axes = [0] gather_axis = 1 : tensor<10x7xf32> -> tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[], [0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[], []] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[ARG]] on @grid_1d grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ %s0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<?x14xf32>
+ %s1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_grid_axis
+func.func @unshard_static_axis_on_dynamic_grid_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = shard.all_gather %[[SOURCE_SHARD]] on @grid_1d_dynamic grid_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+ %s0 = shard.sharding @grid_1d_dynamic split_axes = [[0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<10x14xf32>
+ %s1 = shard.sharding @grid_1d_dynamic split_axes = [[]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
similarity index 100%
rename from mlir/test/Dialect/Mesh/sharding-propagation-failed.mlir
rename to mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
diff --git a/mlir/test/Dialect/Shard/sharding-propagation.mlir b/mlir/test/Dialect/Shard/sharding-propagation.mlir
new file mode 100644
index 0000000000000..34aaf0598b3f0
--- /dev/null
+++ b/mlir/test/Dialect/Shard/sharding-propagation.mlir
@@ -0,0 +1,301 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
+
+shard.grid @grid_2(shape = 2)
+shard.grid @grid_1d(shape = ?)
+shard.grid @grid_2d(shape = 2x4)
+shard.grid @grid_3d(shape = ?x?x?)
+
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: tosa.sigmoid
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: return
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V2]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V2]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[V0]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+ // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S1]] : tensor<8x16xf32>
+ %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = shard.shard %[[V3]] to %[[S1]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
+ // CHECK-NEXT: %[[V6:.*]] = shard.shard %[[V5]] to %[[S1]] : tensor<8x16xf32>
+ %1 = tosa.abs %0: (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[ZP1:.*]] = shard.shard %arg1 to %[[S3]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[ZP2:.*]] = shard.shard %arg2 to %[[S3]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]], %[[ZP1]], %[[ZP2]]
+ // CHECK-NEXT: %[[V8:.*]] = shard.shard %[[V7]] to %[[S1]] : tensor<8x16xf32>
+ %2 = tosa.negate %0, %arg1, %arg2 : (tensor<8x16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<8x16xf32>
+ %s3 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %3 = shard.shard %2 to %s3 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V6]], %[[V8]]
+ return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+ %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] : tensor<2x16x32xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_n
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_m_and_n(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [], [1]] : !shard.sharding
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x16x32xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding
+ %1 = shard.shard %0 to %s1 : tensor<2x16x32xf32>
+ // CHECK-NEXT: return [[vsharded_5]]
+ return %1 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x16x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0], [1]] : !shard.sharding
+ %s0 = shard.sharding @grid_2d split_axes = [[], [0], [1]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x16x8xf32>
+ %arg0_s = shard.shard %arg0 to %s0 : tensor<2x16x8xf32>
+ // CHECK: [[vsharded_0:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x16x8xf32>
+ // CHECK: [[vsharding_1:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding
+ // CHECK: [[vsharded_2:%.*]] = shard.shard [[varg1]] to [[vsharding_1]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_3:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_4:%.*]] = shard.shard [[varg2]] to [[vsharding_3]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ // CHECK: [[vsharding_5:%.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+ // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_5]] : tensor<2x16x32xf32>
+ %0 = tosa.matmul %arg0_s, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK: return [[vsharded_6]]
+ return %0 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicted_k
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1], [0]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[], [1], [0]] : !shard.sharding
+ %0 = shard.shard %arg0 to %s0 annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[S1:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [0]] : !shard.sharding
+ // CHECK-NEXT: %[[V1:.*]] = shard.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
+ %s1 = shard.sharding @grid_2d split_axes = [[], [0]] : !shard.sharding
+ %1 = shard.shard %arg1 to %s1 annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[S2:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[ZP:.*]] = shard.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+ %2 = tosa.matmul %0, %1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S3:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}], [1]] : !shard.sharding
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S3]] : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %2 : tensor<2x16x32xf32>
+}
+
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+ // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+ %arg0: tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+ %arg1: tensor<3x2xf32>,
+ // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+ // CHECK: %[[SIN1_SHARDED1:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}0]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_SHARDED1:.*]] = shard.shard %[[IN1]] to %[[SIN1_SHARDED1]] : tensor<2x3xf32>
+ // CHECK: %[[SIN2_SHARDED:.*]] = shard.sharding @grid_2 split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[IN1_SHARDED2:.*]] = shard.shard %[[IN1_SHARDED1]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x3xf32>
+ // CHECK-NEXT: %[[IN2_SHARDED:.*]] = shard.shard %[[IN2]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<3x2xf32>
+ // CHECK-NEXT: %[[OUT_DPS_SHARDED:.*]] = shard.shard %[[OUT_DPS]] to %[[SIN2_SHARDED]] annotate_for_users : tensor<2x2xf32>
+ %sarg0_sharded = shard.sharding @grid_2 split_axes = [[0]] : !shard.sharding
+ %arg0_sharded = shard.shard %arg0 to %sarg0_sharded : tensor<2x3xf32>
+ // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+ // CHECK-NEXT: %[[RES:.*]] = shard.shard %[[MATMUL]] to %[[SIN2_SHARDED]] : tensor<2x2xf32>
+ %sres_sharded = shard.sharding @grid_2 split_axes = [[]] : !shard.sharding
+ %res_sharded = shard.shard %res to %sres_sharded : tensor<2x2xf32>
+ // CHECK: return %[[RES]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(a)
+// The sharding propagation results in unnecessary reshards,
+// an optimization pass should be able to remove them.
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
+ %s0 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+ %sharded0 = shard.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ %sharded1 = shard.shard %arg1 to %s0 : tensor<2x8x32xf32>
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_0:%.*]] = shard.shard [[varg1]] to [[vsharding]] : tensor<2x8x32xf32>
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[vsharded]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded_0]] to [[vsharding_2]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_4:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_4]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %1 = tosa.matmul %sharded0, %sharded1, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding_4]] : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[v1:%.*]] = tosa.sigmoid [[vsharded_7]] : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding_4]] : tensor<2x4x32xf32>
+ %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ %sharding = shard.sharding @grid_1d split_axes = [[], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded_9:%.*]] = shard.shard [[varg2]] to [[vsharding_2]] : tensor<2x32x8xf32>
+ %sharded2 = shard.shard %arg2 to %sharding : tensor<2x32x8xf32>
+ // CHECK: [[vsharded_10:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_4]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_9]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+ // CHECK: [[v2:%.*]] = tosa.matmul
+ %3 = tosa.matmul %2, %sharded2, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
+ // CHECK: [[vsharded_12:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+ %s4 = shard.sharding @grid_1d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+ %4 = shard.shard %3 to %s4 : tensor<2x4x8xf32>
+ // CHECK: return [[vsharded_12]]
+ return %4 : tensor<2x4x8xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// The sharding propagation results in unnecessary reshards,
+// an optimization pass should be able to remove them.
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4x8xf32>, [[varg1:%.*]]: tensor<2x8x32xf32>, [[varg2:%.*]]: tensor<2x32x8xf32>, [[varg3:%.*]]: tensor<1xf32>
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>, %arg3: tensor<1xf32>) -> tensor<2x4x8xf32> {
+ // CHECK: [[vsharding:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [], [0, 1, 2]] : !shard.sharding
+ %s0 = shard.sharding @grid_3d split_axes = [[], [], [0, 1, 2]] : !shard.sharding
+ // CHECK: [[vsharded:%.*]] = shard.shard [[varg0]] to [[vsharding]] : tensor<2x4x8xf32>
+ %arg0_s = shard.shard %arg0 to %s0 : tensor<2x4x8xf32>
+ // CHECK: [[vsharding_0:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [0], [1, 2]] : !shard.sharding
+ %s1 = shard.sharding @grid_3d split_axes = [[], [0], [1, 2]] : !shard.sharding
+ // CHECK: [[vsharded_1:%.*]] = shard.shard [[varg1]] to [[vsharding_0]] : tensor<2x8x32xf32>
+ %arg1_s = shard.shard %arg1 to %s1 : tensor<2x8x32xf32>
+ // CHECK: [[vsharding_2:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK: [[vsharded_3:%.*]] = shard.shard [[vsharded]] to [[vsharding_2]] annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_4:%.*]] = shard.shard [[vsharded_1]] to [[vsharding]] annotate_for_users : tensor<2x8x32xf32>
+ // CHECK: [[vsharded_5:%.*]] = shard.shard [[varg3]] to [[vsharding_2]] annotate_for_users : tensor<1xf32>
+ // CHECK: [[v0:%.*]] = tosa.matmul
+ %1 = tosa.matmul %arg0_s, %arg1_s, %arg3, %arg3 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharded_6:%.*]] = shard.shard [[v0]] to [[vsharding]] : tensor<2x4x32xf32>
+ %2 = shard.shard %1 to %s0 : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_7:%.*]] = shard.shard [[vsharded_6]] to [[vsharding]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[v1:%.*]] = tosa.sigmoid
+ // CHECK: [[vsharded_8:%.*]] = shard.shard [[v1]] to [[vsharding]] : tensor<2x4x32xf32>
+ %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: [[vsharding_9:%.*]] = shard.sharding @grid_3d split_axes = {{\[\[}}], [1, 2], [0]] : !shard.sharding
+ %s2 = shard.sharding @grid_3d split_axes = [[], [1, 2], [0]] : !shard.sharding
+ // CHECK: [[vsharded_10:%.*]] = shard.shard [[varg2]] to [[vsharding_9]] : tensor<2x32x8xf32>
+ %arg2_s = shard.shard %arg2 to %s2 : tensor<2x32x8xf32>
+ // CHECK: [[vsharded_11:%.*]] = shard.shard [[vsharded_8]] to [[vsharding_2]] annotate_for_users : tensor<2x4x32xf32>
+ // CHECK: [[vsharded_12:%.*]] = shard.shard [[vsharded_10]] to [[vsharding]] annotate_for_users : tensor<2x32x8xf32>
+ // CHECK: [[v2:%.*]] = tosa.matmul
+ %4 = tosa.matmul %3, %arg2_s, %arg3, %arg3 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x4x8xf32>
+ // CHECK: [[vsharded_13:%.*]] = shard.shard [[v2]] to [[vsharding]] : tensor<2x4x8xf32>
+ %5 = shard.shard %4 to %s0 : tensor<2x4x8xf32>
+ // CHECK: [[vsharded_14:%.*]] = shard.shard [[vsharded_13]] to [[vsharding]] annotate_for_users : tensor<2x4x8xf32>
+ %6 = shard.shard %5 to %s0 annotate_for_users : tensor<2x4x8xf32>
+ // CHECK: return [[vsharded_14]]
+ return %6 : tensor<2x4x8xf32>
+}
+
+// CHECK-LABEL: func.func @elementwise_duplicated_chain
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @elementwise_duplicated_chain(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[S0:.*]] = shard.sharding @grid_2d split_axes = {{\[\[}}]] : !shard.sharding
+ // CHECK-NEXT: %[[V0:.*]] = shard.shard %[[ARG]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = shard.shard %[[V1]] to %[[S0]] : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = shard.shard %[[V2]] to %[[S0]] annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = tosa.sigmoid %[[V3]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V5:.*]] = shard.shard %[[V4]] to %[[S0]] : tensor<8x16xf32>
+ %s0 = shard.sharding @grid_2d split_axes = [[]] : !shard.sharding
+ %2 = shard.shard %1 to %s0 : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V5]]
+ return %2 : tensor<8x16xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Shard/simplifications.mlir
similarity index 69%
rename from mlir/test/Dialect/Mesh/simplifications.mlir
rename to mlir/test/Dialect/Shard/simplifications.mlir
index e955f4c134259..33cd490be744a 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Shard/simplifications.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+// RUN: mlir-opt -test-grid-simplifications %s | FileCheck %s
-mesh.mesh @mesh0(shape = 4x2)
-mesh.mesh @mesh1(shape = 4)
+shard.grid @grid0(shape = 4x2)
+shard.grid @grid1(shape = 4)
// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
// `all_reduce(x + y)`.
@@ -11,13 +11,13 @@ func.func @all_reduce_arith_addf_endomorphism(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xf32>
}
@@ -28,13 +28,13 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]]
// CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
return %2, %2 : tensor<5xf32>, tensor<5xf32>
}
@@ -46,11 +46,11 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
- // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -58,17 +58,17 @@ func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
return %0, %2 : tensor<5xf32>, tensor<5xf32>
}
-// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
-func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid
+func.func @all_reduce_arith_addf_no_endomorphism_different_grid(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
- %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid1
+ %1 = shard.all_reduce %arg1 on @grid1 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -76,17 +76,17 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
return %2 : tensor<5xf32>
}
-// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
-func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_grid_axes(
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [1]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [1]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -100,11 +100,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = max
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0] reduction = max
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = max
: tensor<5xf32> -> tensor<5xf32>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf32>
@@ -118,11 +118,11 @@ func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_elemen
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf64> {
- // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG0]] on @grid0 grid_axes = [0]
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf64>
- // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = shard.all_reduce %[[ARG1]] on @grid0 grid_axes = [0]
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0]
: tensor<5xf32> -> tensor<5xf64>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
%2 = arith.addf %0, %1 : tensor<5xf64>
@@ -138,13 +138,13 @@ func.func @all_reduce_arith_minimumf_endomorphism(
%arg0: tensor<5xf32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
%arg1: tensor<5xf32>) -> tensor<5xf32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
: tensor<5xf32> -> tensor<5xf32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
: tensor<5xf32> -> tensor<5xf32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
%2 = arith.minimumf %0, %1 : tensor<5xf32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xf32>
}
@@ -155,13 +155,13 @@ func.func @all_reduce_arith_minsi_endomorphism(
%arg0: tensor<5xi32>,
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
%arg1: tensor<5xi32>) -> tensor<5xi32> {
- %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = min
+ %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0] reduction = min
: tensor<5xi32> -> tensor<5xi32>
- %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = min
+ %1 = shard.all_reduce %arg1 on @grid0 grid_axes = [0] reduction = min
: tensor<5xi32> -> tensor<5xi32>
// CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
%2 = arith.minsi %0, %1 : tensor<5xi32>
- // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = min
+ // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = shard.all_reduce %[[ADD_RES]] on @grid0 grid_axes = [0] reduction = min
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xi32>
}
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
deleted file mode 100644
index 8598d81ff6cfa..0000000000000
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ /dev/null
@@ -1,52 +0,0 @@
-// RUN: mlir-opt \
-// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-single-fold))" \
-// RUN: %s | FileCheck %s
-
-mesh.mesh @mesh_1d_4(shape = 4)
-
-// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
-func.func @tensor_empty_static_sharded_dims_offsets() -> () {
- %b = tensor.empty() : tensor<8x16xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
- // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
- // CHECK-SAME: ] : index, index
- // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
-
- return
-}
-
-// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
-// CHECK-SAME: %[[A0:.*]]: index
-func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
- %b = tensor.empty(%arg0) : tensor<8x?xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
- // CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
- // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
- // CHECK-SAME: ] : index, index
- // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
-
- return
-}
-
-// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
-func.func @tensor_empty_same_static_dims_sizes() -> () {
- %b = tensor.empty() : tensor<16x16xf32>
- %sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !mesh.sharding
- %sharded= mesh.shard %b to %sharding : tensor<16x16xf32>
- // CHECK-NEXT: tensor.empty() : tensor<4x16xf32>
-
- return
-}
-
-// CHECK-LABEL: func @tensor_empty_0d
-func.func @tensor_empty_0d() -> () {
- tensor.empty() : tensor<f32>
- // CHECK-NEXT: tensor.empty() : tensor<f32>
- return
-}
diff --git a/mlir/test/Dialect/Tensor/shard-partition.mlir b/mlir/test/Dialect/Tensor/shard-partition.mlir
new file mode 100644
index 0000000000000..5918ee1eddf57
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/shard-partition.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid_1d_4(shape = 4)
+
+// CHECK-LABEL: func @tensor_empty_static_sharded_dims_offsets
+func.func @tensor_empty_static_sharded_dims_offsets() -> () {
+ %b = tensor.empty() : tensor<8x16xf32>
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+ %sharded= shard.shard %b to %sharding : tensor<8x16xf32>
+ // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+ // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
+ // CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_dynamic_sharded_dims_offsets
+// CHECK-SAME: %[[A0:.*]]: index
+func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
+ %b = tensor.empty(%arg0) : tensor<8x?xf32>
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+ %sharded= shard.shard %b to %sharding : tensor<8x?xf32>
+ // CHECK: %[[sharding:.*]] = shard.sharding @grid_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !shard.sharding
+ // CHECK: %[[proc_multi_idx:.*]] = shard.process_multi_index on @grid_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = shard.shard_shape dims = [8, %[[A0]]
+ // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
+ // CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_same_static_dims_sizes
+func.func @tensor_empty_same_static_dims_sizes() -> () {
+ %b = tensor.empty() : tensor<16x16xf32>
+ %sharding = shard.sharding @grid_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 4, 8, 12, 16] : !shard.sharding
+ %sharded= shard.shard %b to %sharding : tensor<16x16xf32>
+ // CHECK-NEXT: tensor.empty() : tensor<4x16xf32>
+
+ return
+}
+
+// CHECK-LABEL: func @tensor_empty_0d
+func.func @tensor_empty_0d() -> () {
+ tensor.empty() : tensor<f32>
+ // CHECK-NEXT: tensor.empty() : tensor<f32>
+ return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index eb2f74e8aeca1..3b7bd9b9637a8 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -10,7 +10,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVM)
add_subdirectory(Math)
add_subdirectory(MemRef)
-add_subdirectory(Mesh)
+add_subdirectory(Shard)
add_subdirectory(NVGPU)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
similarity index 51%
rename from mlir/test/lib/Dialect/Mesh/CMakeLists.txt
rename to mlir/test/lib/Dialect/Shard/CMakeLists.txt
index 7bd0493d11a7e..f91c54721e030 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Shard/CMakeLists.txt
@@ -1,14 +1,14 @@
# Exclude tests from libMLIR.so
-add_mlir_library(MLIRMeshTest
+add_mlir_library(MLIRShardTest
TestOpLowering.cpp
- TestReshardingSpmdization.cpp
+ TestReshardingPartition.cpp
TestSimplifications.cpp
EXCLUDE_FROM_LIBMLIR
)
-mlir_target_link_libraries(MLIRMeshTest PUBLIC
- MLIRMeshDialect
- MLIRMeshTransforms
+mlir_target_link_libraries(MLIRShardTest PUBLIC
+ MLIRShardDialect
+ MLIRShardTransforms
MLIRPass
MLIRRewrite
MLIRTransformUtils
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
similarity index 80%
rename from mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
rename to mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
index dbae93b380f2b..43f3b3f239181 100644
--- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestOpLowering.cpp
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Shard/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -24,17 +24,17 @@ struct TestAllSliceOpLoweringPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
+ shard::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsGreedily failed.");
}
 void getDependentDialects(DialectRegistry &registry) const override {
- mesh::registerAllSliceOpLoweringDialects(registry);
+ shard::registerAllSliceOpLoweringDialects(registry);
}
StringRef getArgument() const final {
- return "test-mesh-all-slice-op-lowering";
+ return "test-grid-all-slice-op-lowering";
}
StringRef getDescription() const final {
return "Test lowering of all-slice.";
@@ -48,21 +48,21 @@ struct TestMultiIndexOpLoweringPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
- symbolTableCollection);
+ shard::populateProcessMultiIndexOpLoweringPatterns(patterns,
+ symbolTableCollection);
LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsGreedily failed.");
}
 void getDependentDialects(DialectRegistry &registry) const override {
- mesh::registerProcessMultiIndexOpLoweringDialects(registry);
+ shard::registerProcessMultiIndexOpLoweringDialects(registry);
}
StringRef getArgument() const final {
- return "test-mesh-process-multi-index-op-lowering";
+ return "test-grid-process-multi-index-op-lowering";
}
StringRef getDescription() const final {
- return "Test lowering of mesh.process_multi_index op.";
+ return "Test lowering of shard.process_multi_index op.";
}
};
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
similarity index 75%
rename from mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
rename to mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
index 102e64de4bd1f..ac71ff60fc509 100644
--- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Transforms/Partition.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -22,11 +22,11 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mlir::shard;
namespace {
-struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+struct TestReshardingRewritePattern : OpRewritePattern<ShardOp> {
using OpRewritePattern<ShardOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ShardOp op,
@@ -36,18 +36,18 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
}
SymbolTableCollection symbolTable;
- mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
- op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr());
+ shard::GridOp grid = symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
+ op, cast<ShardingOp>(op.getSharding().getDefiningOp()).getGridAttr());
bool foundUser = false;
for (auto user : op->getUsers()) {
if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
if (targetShardOp.getAnnotateForUsers() &&
- mesh == symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ grid == symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
targetShardOp,
cast<ShardingOp>(
targetShardOp.getSharding().getDefiningOp())
- .getMeshAttr())) {
+ .getGridAttr())) {
foundUser = true;
break;
}
@@ -61,22 +61,22 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
for (auto user : op->getUsers()) {
auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
- symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(
+ symbolTable.lookupNearestSymbolFrom<shard::GridOp>(
targetShardOp,
cast<ShardingOp>(targetShardOp.getSharding().getDefiningOp())
- .getMeshAttr()) != mesh) {
+ .getGridAttr()) != grid) {
continue;
}
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ShapedType sourceShardShape =
- shardShapedType(op.getResult().getType(), mesh, op.getSharding());
+ shardShapedType(op.getResult().getType(), grid, op.getSharding());
TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
builder
.create<UnrealizedConversionCastOp>(sourceShardShape, op.getSrc())
->getResult(0));
TypedValue<ShapedType> targetShard =
- reshard(builder, mesh, op, targetShardOp, sourceShard);
+ reshard(builder, grid, op, targetShardOp, sourceShard);
Value newTargetUnsharded =
builder
.create<UnrealizedConversionCastOp>(
@@ -90,13 +90,13 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
}
};
-struct TestMeshReshardingPass
- : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+struct TestReshardingPass
+ : public PassWrapper<TestReshardingPass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReshardingPass)
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+ patterns.insert<TestReshardingRewritePattern>(&getContext());
if (failed(applyPatternsGreedily(getOperation().getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -107,18 +107,18 @@ struct TestMeshReshardingPass
registry.insert<BuiltinDialect>();
}
StringRef getArgument() const final {
- return "test-mesh-resharding-spmdization";
+ return "test-grid-resharding-partition";
}
StringRef getDescription() const final {
- return "Test Mesh dialect resharding spmdization.";
+ return "Test Shard dialect resharding partition.";
}
};
} // namespace
namespace mlir {
namespace test {
-void registerTestMeshReshardingSpmdizationPass() {
- PassRegistration<TestMeshReshardingPass>();
+void registerTestReshardingPartitionPass() {
+ PassRegistration<TestReshardingPass>();
}
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
similarity index 60%
rename from mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
rename to mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
index 01e196d29f7a5..28852153f37f6 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Shard/TestSimplifications.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -16,23 +16,23 @@
using namespace mlir;
namespace {
-struct TestMeshSimplificationsPass
- : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+struct TestShardSimplificationsPass
+ : public PassWrapper<TestShardSimplificationsPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestShardSimplificationsPass)
void runOnOperation() override;
 void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+ registry.insert<arith::ArithDialect, shard::ShardDialect>();
}
- StringRef getArgument() const final { return "test-mesh-simplifications"; }
- StringRef getDescription() const final { return "Test mesh simplifications"; }
+ StringRef getArgument() const final { return "test-grid-simplifications"; }
+ StringRef getDescription() const final { return "Test grid simplifications"; }
};
} // namespace
-void TestMeshSimplificationsPass::runOnOperation() {
+void TestShardSimplificationsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
+ shard::populateSimplificationPatterns(patterns, symbolTableCollection);
[[maybe_unused]] LogicalResult status =
applyPatternsGreedily(getOperation(), std::move(patterns));
assert(succeeded(status) && "Rewrite patters application did not converge.");
@@ -40,8 +40,8 @@ void TestMeshSimplificationsPass::runOnOperation() {
namespace mlir {
namespace test {
-void registerTestMeshSimplificationsPass() {
- PassRegistration<TestMeshSimplificationsPass>();
+void registerTestShardSimplificationsPass() {
+ PassRegistration<TestShardSimplificationsPass>();
}
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 26d7597347a8a..6958fe3001b89 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -29,7 +29,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestMathToVCIX
MLIRMemRefTestPasses
MLIRTestMemRefToLLVMWithTransforms
- MLIRMeshTest
+ MLIRShardTest
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 143a5e8e8f8dd..2c0975302e6a5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -130,8 +130,8 @@ void registerTestIrdlTestDialectConversionPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMemRefToLLVMWithTransforms();
-void registerTestMeshReshardingSpmdizationPass();
-void registerTestMeshSimplificationsPass();
+void registerTestReshardingPartitionPass();
+void registerTestShardSimplificationsPass();
void registerTestMultiBuffering();
void registerTestNextAccessPass();
void registerTestNVGPULowerings();
@@ -276,8 +276,8 @@ void registerTestPasses() {
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMemRefToLLVMWithTransforms();
- mlir::test::registerTestMeshReshardingSpmdizationPass();
- mlir::test::registerTestMeshSimplificationsPass();
+ mlir::test::registerTestReshardingPartitionPass();
+ mlir::test::registerTestShardSimplificationsPass();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestNVGPULowerings();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9ec7c51da4065..6c528f8f7b6bd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3056,14 +3056,14 @@ cc_library(
)
##---------------------------------------------------------------------------##
-# Mesh Dialect
+# Shard Dialect
##---------------------------------------------------------------------------##
td_library(
- name = "MeshTdFiles",
+ name = "ShardTdFiles",
srcs = [
- "include/mlir/Dialect/Mesh/IR/MeshBase.td",
- "include/mlir/Dialect/Mesh/IR/MeshOps.td",
+ "include/mlir/Dialect/Shard/IR/ShardBase.td",
+ "include/mlir/Dialect/Shard/IR/ShardOps.td",
],
includes = ["include"],
deps = [
@@ -3075,92 +3075,92 @@ td_library(
)
gentbl_cc_library(
- name = "MeshIncGen",
+ name = "ShardIncGen",
tbl_outs = {
- "include/mlir/Dialect/Mesh/IR/MeshOps.h.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardOps.h.inc": [
"-gen-op-decls",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshOps.cpp.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardOps.cpp.inc": [
"-gen-op-defs",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshDialect.h.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardDialect.h.inc": [
"-gen-dialect-decls",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardDialect.cpp.inc": [
"-gen-dialect-defs",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshEnums.h.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardEnums.h.inc": [
"-gen-enum-decls",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardEnums.cpp.inc": [
"-gen-enum-defs",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshAttributes.h.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardAttributes.h.inc": [
"-gen-attrdef-decls",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc": [
"-gen-attrdef-defs",
- "-dialect=mesh",
+ "-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshTypes.h.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardTypes.h.inc": [
"-gen-typedef-decls",
- "-typedefs-dialect=mesh",
+ "-typedefs-dialect=shard",
],
- "include/mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc": [
+ "include/mlir/Dialect/Shard/IR/ShardTypes.cpp.inc": [
"-gen-typedef-defs",
- "-typedefs-dialect=mesh",
+ "-typedefs-dialect=shard",
],
},
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Mesh/IR/MeshOps.td",
+ td_file = "include/mlir/Dialect/Shard/IR/ShardOps.td",
deps = [
- ":MeshTdFiles",
+ ":ShardTdFiles",
":ShapeOpsTdFiles",
],
)
gentbl_cc_library(
- name = "MeshShardingInterfaceIncGen",
+ name = "ShardingInterfaceIncGen",
tbl_outs = {
- "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc": ["-gen-op-interface-decls"],
- "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc": ["-gen-op-interface-defs"],
+ "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h.inc": ["-gen-op-interface-decls"],
+ "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc": ["-gen-op-interface-defs"],
},
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td",
+ td_file = "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.td",
deps = [":OpBaseTdFiles"],
)
cc_library(
- name = "MeshShardingInterface",
- srcs = ["lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp"],
+ name = "ShardingInterface",
+ srcs = ["lib/Dialect/Shard/Interfaces/ShardingInterface.cpp"],
hdrs = [
- "include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h",
- "include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h",
+ "include/mlir/Dialect/Shard/Interfaces/ShardingInterface.h",
+ "include/mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h",
],
includes = ["include"],
deps = [
":DialectUtils",
":IR",
- ":MeshDialect",
- ":MeshShardingInterfaceIncGen",
+ ":ShardDialect",
+ ":ShardingInterfaceIncGen",
":Support",
"//llvm:Support",
],
)
cc_library(
- name = "MeshDialect",
- srcs = ["lib/Dialect/Mesh/IR/MeshOps.cpp"],
+ name = "ShardDialect",
+ srcs = ["lib/Dialect/Shard/IR/ShardOps.cpp"],
hdrs = [
- "include/mlir/Dialect/Mesh/IR/MeshDialect.h",
- "include/mlir/Dialect/Mesh/IR/MeshOps.h",
+ "include/mlir/Dialect/Shard/IR/ShardDialect.h",
+ "include/mlir/Dialect/Shard/IR/ShardOps.h",
],
includes = ["include"],
deps = [
@@ -3171,7 +3171,7 @@ cc_library(
":IR",
":InferTypeOpInterface",
":InliningUtils",
- ":MeshIncGen",
+ ":ShardIncGen",
":SideEffectInterfaces",
":Support",
":ViewLikeInterface",
@@ -3180,23 +3180,23 @@ cc_library(
)
gentbl_cc_library(
- name = "MeshTransformsPassIncGen",
- tbl_outs = {"include/mlir/Dialect/Mesh/Transforms/Passes.h.inc": [
+ name = "ShardTransformsPassIncGen",
+ tbl_outs = {"include/mlir/Dialect/Shard/Transforms/Passes.h.inc": [
"-gen-pass-decls",
- "-name=Mesh",
+ "-name=Shard",
]},
tblgen = ":mlir-tblgen",
- td_file = "include/mlir/Dialect/Mesh/Transforms/Passes.td",
+ td_file = "include/mlir/Dialect/Shard/Transforms/Passes.td",
deps = [":PassBaseTdFiles"],
)
cc_library(
- name = "MeshTransforms",
+ name = "ShardTransforms",
srcs = glob([
- "lib/Dialect/Mesh/Transforms/*.cpp",
- "lib/Dialect/Mesh/Transforms/*.h",
+ "lib/Dialect/Shard/Transforms/*.cpp",
+ "lib/Dialect/Shard/Transforms/*.h",
]),
- hdrs = glob(["include/mlir/Dialect/Mesh/Transforms/*.h"]),
+ hdrs = glob(["include/mlir/Dialect/Shard/Transforms/*.h"]),
includes = ["include"],
deps = [
":AffineDialect",
@@ -3209,9 +3209,9 @@ cc_library(
":FuncDialect",
":FunctionInterfaces",
":IR",
- ":MeshDialect",
- ":MeshShardingInterface",
- ":MeshTransformsPassIncGen",
+ ":ShardDialect",
+ ":ShardingInterface",
+ ":ShardTransformsPassIncGen",
":Pass",
":Support",
":TensorDialect",
@@ -3221,11 +3221,11 @@ cc_library(
)
cc_library(
- name = "MeshToMPIConversion",
+ name = "ShardToMPIConversion",
srcs = glob([
- "lib/Conversion/MeshToMPI/*.cpp",
+ "lib/Conversion/ShardToMPI/*.cpp",
]),
- hdrs = glob(["include/mlir/Conversion/MeshToMPI/*.h"]),
+ hdrs = glob(["include/mlir/Conversion/ShardToMPI/*.h"]),
includes = ["include"],
deps = [
":AffineDialect",
@@ -3240,8 +3240,8 @@ cc_library(
":LinalgDialect",
":MPIDialect",
":MemRefDialect",
- ":MeshDialect",
- ":MeshTransforms",
+ ":ShardDialect",
+ ":ShardTransforms",
":Pass",
":SCFDialect",
":Support",
@@ -3988,7 +3988,7 @@ cc_library(
":MemRefToEmitC",
":MemRefToLLVM",
":MemRefToSPIRV",
- ":MeshToMPIConversion",
+ ":ShardToMPIConversion",
":NVGPUToNVVM",
":NVVMToLLVM",
":OpenACCToSCF",
@@ -4522,7 +4522,7 @@ cc_library(
":FuncDialect",
":IR",
":InliningUtils",
- ":MeshShardingInterface",
+ ":ShardingInterface",
],
)
@@ -4621,7 +4621,7 @@ cc_library(
":MemRefToEmitC",
":MemRefToLLVM",
":MemRefTransformOps",
- ":MeshDialect",
+ ":ShardDialect",
":NVGPUTransformOps",
":NVVMTarget",
":NVVMToLLVM",
@@ -7194,7 +7194,7 @@ cc_library(
includes = ["include"],
deps = [
":IR",
- ":MeshShardingInterface",
+ ":ShardingInterface",
":TensorDialect",
"//llvm:Support",
],
@@ -9019,8 +9019,8 @@ cc_library(
":MemRefToSPIRV",
":MemRefTransformOps",
":MemRefTransforms",
- ":MeshDialect",
- ":MeshTransforms",
+ ":ShardDialect",
+ ":ShardTransforms",
":NVGPUDialect",
":NVGPUPassIncGen",
":NVGPUToNVVM",
@@ -9120,7 +9120,7 @@ cc_binary(
"//mlir/test:TestMath",
"//mlir/test:TestMathToVCIX",
"//mlir/test:TestMemRef",
- "//mlir/test:TestMesh",
+ "//mlir/test:TestShard",
"//mlir/test:TestNVGPU",
"//mlir/test:TestPDLL",
"//mlir/test:TestPass",
@@ -9182,7 +9182,7 @@ cc_binary(
"//mlir/test:TestMathToVCIX",
"//mlir/test:TestMemRef",
"//mlir/test:TestMemRefToLLVMWithTransforms",
- "//mlir/test:TestMesh",
+ "//mlir/test:TestShard",
"//mlir/test:TestNVGPU",
"//mlir/test:TestPDLL",
"//mlir/test:TestPass",
@@ -10548,7 +10548,7 @@ cc_library(
":LinalgStructuredOpsIncGen",
":MathDialect",
":MemRefDialect",
- ":MeshShardingInterface",
+ ":ShardingInterface",
":Parser",
":SCFDialect",
":SideEffectInterfaces",
@@ -10699,9 +10699,9 @@ cc_library(
":MathDialect",
":MemRefDialect",
":MemRefTransforms",
- ":MeshDialect",
- ":MeshShardingInterface",
- ":MeshTransforms",
+ ":ShardDialect",
+ ":ShardingInterface",
+ ":ShardTransforms",
":Pass",
":RuntimeVerifiableOpInterface",
":SCFDialect",
@@ -11198,8 +11198,8 @@ cc_library(
":InferTypeOpInterface",
":InliningUtils",
":LoopLikeInterface",
- ":MeshDialect",
- ":MeshShardingInterface",
+ ":ShardDialect",
+ ":ShardingInterface",
":Pass",
":QuantOps",
":SideEffectInterfaces",
@@ -12141,7 +12141,7 @@ cc_library(
":FuncTransforms",
":IR",
":MemRefDialect",
- ":MeshShardingInterface",
+ ":ShardingInterface",
":Pass",
":SideEffectInterfaces",
":TensorDialect",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 95e3ee4df7bc5..e7770fcc9eabd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -903,8 +903,8 @@ cc_library(
)
cc_library(
- name = "TestMesh",
- srcs = glob(["lib/Dialect/Mesh/**/*.cpp"]),
+ name = "TestShard",
+ srcs = glob(["lib/Dialect/Shard/**/*.cpp"]),
includes = ["lib/Dialect/Test"],
deps = [
":TestDialect",
@@ -912,8 +912,8 @@ cc_library(
"//mlir:DialectUtils",
"//mlir:FuncDialect",
"//mlir:IR",
- "//mlir:MeshDialect",
- "//mlir:MeshTransforms",
+ "//mlir:ShardDialect",
+ "//mlir:ShardTransforms",
"//mlir:Pass",
"//mlir:SPIRVDialect",
"//mlir:Support",
>From 354baf49b7188797f9571538ded69fd0ede1c193 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 24 Jul 2025 11:14:11 +0200
Subject: [PATCH 2/2] mention rename mesh->shard in shard doc
---
mlir/docs/Dialects/Shard.md | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/docs/Dialects/Shard.md b/mlir/docs/Dialects/Shard.md
index 714b340db4cde..153231d156edc 100644
--- a/mlir/docs/Dialects/Shard.md
+++ b/mlir/docs/Dialects/Shard.md
@@ -1,8 +1,10 @@
# 'shard' Dialect
-The `shard` dialect contains a set of attributes, operations and interfaces that
-are useful for representing sharding and communication on a device grid
-cluster.
+This dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device grid.
+
+It was originally introduced under the name 'mesh' but was later renamed
+to better reflect its purpose.
[TOC]
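
For readers following the rename, a minimal sketch of the new surface syntax, distilled from the test cases in this patch (a 1-D device grid of size 4, with a tensor split along its first dimension over grid axis 0):

  shard.grid @grid_1d_4(shape = 4)

  func.func @example(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
    // Split tensor dimension 0 across grid axis 0 of @grid_1d_4
    // (formerly mesh.sharding / !mesh.sharding).
    %s = shard.sharding @grid_1d_4 split_axes = [[0]] : !shard.sharding
    // Attach the sharding annotation to the value (formerly mesh.shard).
    %sharded = shard.shard %arg0 to %s : tensor<8x16xf32>
    return %sharded : tensor<8x16xf32>
  }

The ops map one-to-one onto their former mesh.* spellings; only the dialect prefix and the mesh -> grid terminology (e.g. mesh_axes -> grid_axes) change.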