[Mlir-commits] [mlir] 08545e8 - [MLIR] Add a new Mesh dialect (#68007)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 10 11:35:46 PDT 2023
Author: Chengji Yao
Date: 2023-10-10T11:35:40-07:00
New Revision: 08545e85167a105b8147d76a48a2fa1eac0f9e9a
URL: https://github.com/llvm/llvm-project/commit/08545e85167a105b8147d76a48a2fa1eac0f9e9a
DIFF: https://github.com/llvm/llvm-project/commit/08545e85167a105b8147d76a48a2fa1eac0f9e9a.diff
LOG: [MLIR] Add a new Mesh dialect (#68007)
This is the 1st PR of [Mesh sharding
RFC](https://discourse.llvm.org/t/open-mlir-meeting-9-28-2023-rfc-sharding-framework-design-for-device-mesh/73695),
includes
Includes:
- mesh.cluster op
- mesh.shard op (the mesh.annotate op in the RFC slides, the name is
modified a bit from @stellaraccident 's advice, which I think might be a
bit more concise)
- MeshSharding attribute
Added:
mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/lib/Dialect/Mesh/CMakeLists.txt
mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/test/Dialect/Mesh/invalid.mlir
mlir/test/Dialect/Mesh/ops.mlir
Modified:
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 9082d9633339c9f..1c4569ecfa58485 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
+add_subdirectory(Mesh)
add_subdirectory(MLProgram)
add_subdirectory(NVGPU)
add_subdirectory(OpenACC)
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..f33061b2d87cffc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
new file mode 100644
index 000000000000000..cfc948e305638fa
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect(MeshOps mesh)
+add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+add_mlir_doc(MeshOps MeshAttributes Dialects/ -gen-attrdef-doc)
+
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
new file mode 100644
index 000000000000000..d761743a82bf86b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -0,0 +1,128 @@
+//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+
+//===----------------------------------------------------------------------===//
+// Mesh Dialect
+//===----------------------------------------------------------------------===//
+
+def Mesh_Dialect : Dialect {
+ let name = "mesh";
+ let cppNamespace = "::mlir::mesh";
+
+ let description = [{
+ The `mesh` dialect contains a set of attributes, operations, interfaces that
+ are useful for representing sharding and communication on device mesh
+ cluster.
+ }];
+
+ let dependentDialects = [
+ "arith::ArithDialect" // For materializeConstant()
+ ];
+
+ let useDefaultAttributePrinterParser = 1;
+ let hasConstantMaterializer = 1;
+}
+//===----------------------------------------------------------------------===//
+// Mesh Enums.
+//===----------------------------------------------------------------------===//
+
+def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
+ I32EnumAttrCase<"Sum", 1, "sum">,
+ I32EnumAttrCase<"Max", 2, "max">,
+ I32EnumAttrCase<"Min", 3, "min">,
+ I32EnumAttrCase<"Generic", 100, "generic">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::mesh";
+}
+
+//===----------------------------------------------------------------------===//
+// Mesh Attribute
+//===----------------------------------------------------------------------===//
+
+def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
+ let mnemonic = "shard";
+
+ let parameters = (ins
+ AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+ ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
+ OptionalArrayRefParameter<"int8_t">:$partial_axes,
+ OptionalParameter<"::mlir::mesh::Partial">:$partial_type
+ );
+
+ let summary = "Attribute that extends tensor type to distributed tensor type.";
+
+ let description = [{
+ The MeshSharding attribute could be used in the encoding of a
+ `RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
+
+ 1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+ cluster where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.cluster` operation.
+
+ 2. `split_axes`: is an array composed of int64_t sub-arrays. The outer array's
+ maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+ its value is [x, y], it indicates that the tensor's i-th dimension is splitted
+ along the x and y axes of the device mesh.
+
+ 3. `partial_axes`: if not empty, this signifies that the tensor is partial
+ one along the specified mesh axes. An all-reduce should be applied to obtain
+ the complete tensor, with reduction type being specified by `partial_type`.
+
+ 4. `partial_type`: indicates the reduction type of the possible all-reduce
+ op. It has 4 possible values:
+ - `partial_sum`: denotes it's an all-reduce-sum
+ - `partial_max`: denotes it's an all-reduce-max
+ - `partial_min`: denotes it's an all-reduce-min
+ - `partial_generic`: denotes that the all-reduce type is complex and cannot
+ be represented merely by a simple sum, max, or min. The exact reduction
+ computation may be derived from the semantics of the corresponding operation
+ or from the reduction computation IR
+
+ Example:
+
+ ```
+ mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+
+ // The tensor is fully replicated on @mesh0.
+ // Currently, there must be at least one sub-array present in axes, even
+ // if it's empty. Otherwise, a parsing error will occur.
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
+
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+ // it is also a partial_sum along mesh axis 1.
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+
+ // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+ // it is also a partial_max along mesh axis 1.
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
+
+ // Could be used in the attribute of mesh.shard op
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ ```
+ }];
+ let assemblyFormat = [{
+ `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
+ $partial_axes^ `]`)? `>`
+ }];
+
+ let genVerifyDecl = 1;
+}
+
+#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
new file mode 100644
index 000000000000000..9dfeca84d012165
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -0,0 +1,27 @@
+//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
+#define MLIR_DIALECT_MESH_IR_MESHOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
new file mode 100644
index 000000000000000..8ca4b6653104221
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -0,0 +1,174 @@
+//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
+#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
+
+include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Mesh Dialect operations.
+//===----------------------------------------------------------------------===//
+
+class Mesh_Op<string mnemonic, list<Trait> traits = []> :
+ Op<Mesh_Dialect, mnemonic, traits> {
+}
+
+def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
+ let summary = "representing a mesh cluster";
+ let description = [{
+ The mesh.cluster operation is a symbol operation that identifies a specific
+ mesh cluster. The operation has three attributes:
+
+ 1. `sym_name`: This attribute uniquely identifies the name of the mesh
+ cluster. This name serves as a symbolic reference to the cluster throughout
+ the MLIR module, allowing for consistent referencing and easier debugging.
+
+ 2. `rank`: This attribute specifies the number of axes of the cluster. The
+ rank indicates the dimensionality of the mesh cluster and can be used to
+ determine the layout and the addressing space of the computation distributed
+ across the mesh.
+
+ 3. `dim_sizes`: This attribute represents the device assignment along the
+ axes of the cluster. Each integer in the array corresponds to the number of
+ devices along a specific axis. If an integer value is 0, it implies that the
+ number of devices along that axis is unknown. This flexibility allows for
+ dynamic device assignment or configurations where the exact number of
+ devices might not be determined during compile time.
+
+ Example:
+ ```
+ // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
+ // The dimension sizes are 4, 8, 12
+ mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
+
+ // A device mesh cluster with 2 axes, the total device number is unknown
+ // The first dimension size is 4 and the second is unknown
+ mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+
+ // A device mesh cluster with 2 axes, the total device number is unknown
+ // The first dimension size is unknown and the second is 4
+ mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+
+ // A device mesh cluster with 2 axes, the number of devices along both axes
+ // is unknown
+ mesh.cluster @mesh3(rank = 2)
+
+ // Used in the mesh sharding attribute to extend the standard tensor to
+ // distributed
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
+ ```
+ }];
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ I8Attr:$rank,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+ );
+ let assemblyFormat = [{
+ $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
+ attr-dict
+ }];
+ let hasVerifier = 1;
+}
+
+def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+ let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
+ let description = [{
+ The mesh.shard operation is designed to specify and guide the sharding
+ behavior of a tensor value across a mesh topology. This operation has one
+ operand and two attributes:
+
+ 1. `input`: This operand represents the tensor value that needs to be
+ annotated for sharding.
+
+ 2. `shard`: This attribute is type of `MeshSharding`, which is the core data
+ structure to represent distributed tensor in mesh cluster.
+
+ 3. `annotate_for_users`: A unit attribute addressing the scenario when a
+    tensor's sharding annotation differs based on its context of use (either as
+ a result or an operand). If specified, the sharding pertains to specific
+ users of the tensor value, indicating how it should be considered when used
+ as an operand in subsequent operations. If not, the sharding applies to the
+ operation that defines the tensor value.
+
+ Example:
+ ```
+ func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ ...
+ }
+
+ func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+ ...
+ }
+
+ // The first mesh.shard op applies to %arg0, the second mesh.shard op
+ // applies for the operand of op0, the third mesh.shard op applies for the
+ // operand of op2
+ func.func @both_result_and_multi_operands_annotated(
+ %arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+ %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
+ "op0"(%1) : ...
+ "op1"(%2) : ...
+ ...
+ }
+ ```
+
+ The following usages are undefined:
+ ```
+  func.func @annotate_on_same_result_with_different_sharding(
+ %arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+ ...
+ }
+
+  func.func @annotate_on_same_result_same_value_with_different_sharding(
+ %arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+ ...
+ }
+
+  func.func @annotate_on_same_operand_with_different_sharding(
+ %arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+ %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+ ...
+ }
+
+ func.func @result_annotated_after_operand(
+ %arg0 : tensor<4x8xf32>) -> () {
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+ %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+ ...
+ }
+ ```
+ }];
+ let arguments = (ins
+ Builtin_RankedTensor:$src,
+ MeshSharding:$shard,
+ UnitAttr:$annotate_for_users
+ );
+ let results = (outs
+ Builtin_RankedTensor:$result
+ );
+ let assemblyFormat = [{
+ $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
+ type($result)
+ }];
+}
+
+#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f1fabf95a68b7ad..379392ace46961a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,7 +350,8 @@ template <typename AsmPrinterT, typename T,
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, APFloat &>::value &&
- !llvm::is_one_of<T, bool, float, double>::value,
+ !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
+ double>::value,
T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
@@ -366,6 +367,17 @@ operator<<(AsmPrinterT &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
+/// Specialization for 8-bit integers to ensure values are printed as integers
+// and not characters.
+template <
+ typename AsmPrinterT, typename T,
+ std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, T value) {
+ return p << static_cast<int16_t>(value);
+}
+
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index d04ed373ecf045a..00f400aab5d50a0 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,6 +55,7 @@
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -117,6 +118,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
LLVM::LLVMDialect,
math::MathDialect,
memref::MemRefDialect,
+ mesh::MeshDialect,
ml_program::MLProgramDialect,
nvgpu::NVGPUDialect,
NVVM::NVVMDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index d9f6b0fb7e63c2a..68776a695cac4d4 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
+add_subdirectory(Mesh)
add_subdirectory(MLProgram)
add_subdirectory(NVGPU)
add_subdirectory(OpenACC)
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..f33061b2d87cffc
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
new file mode 100644
index 000000000000000..700e6e21f36b677
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRMeshDialect
+ MeshOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+ DEPENDS
+ MLIRMeshOpsAttrIncGen
+ MLIRMeshOpsEnumsIncGen
+ MLIRMeshOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRIR
+ MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
new file mode 100644
index 000000000000000..b2a47102528758c
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -0,0 +1,112 @@
+//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "mesh-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Mesh dialect
+//===----------------------------------------------------------------------===//
+
+void MeshDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+ >();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+ >();
+}
+
+Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
+ Type type, Location loc) {
+ return arith::ConstantOp::materialize(builder, value, type, loc);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.cluster op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ClusterOp::verify() {
+ ArrayRef<int64_t> dimSizes = getDimSizes();
+ uint8_t rank = getRank();
+
+ if (rank == 0)
+ return emitOpError("rank of cluster is expected to be a positive integer");
+
+ if (dimSizes.size() > rank)
+ return emitOpError(
+ "rank of dim_sizes is not expected to be larger than rank of cluster");
+
+ for (int64_t dimSize : dimSizes) {
+ if (dimSize < 0)
+ return emitOpError(
+ "dimension size of a mesh cluster is expected to be non-negative");
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
+ ArrayRef<int8_t> partialAxes, Partial) {
+ // TODO: At present cluster symbol ref is not verified. This is due to the
+  // difficulty in fetching the corresponding symbol op based on an attribute.
+
+ llvm::SmallSet<int8_t, 4> visitedAxes;
+
+ auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
+ for (int8_t axis : axesArray) {
+ if (axis < 0)
+ return emitError() << "mesh axis is expected to be non-negative";
+ if (!visitedAxes.insert(axis).second)
+ return emitError() << "mesh axis duplicated";
+ }
+ return success();
+ };
+
+ for (DenseI8ArrayAttr subAxes : splitAxes) {
+ ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+ if (failed(checkMeshAxis(subAxesArray)))
+ return failure();
+ }
+ if (failed(checkMeshAxis(partialAxes)))
+ return failure();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
new file mode 100644
index 000000000000000..246439dd4be7122
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+// expected-error at +1 {{rank of cluster is expected to be a positive integer}}
+mesh.cluster @mesh0(rank = 0)
+
+// -----
+
+// expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4])
+
+// -----
+
+// expected-error at +1 {{dimension size of a mesh cluster is expected to be non-negative}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = [-1])
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_duplicated_different_subarray(
+ // expected-error at +1 {{mesh axis duplicated}}
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_duplicated_same_subarray(
+ // expected-error at +1 {{mesh axis duplicated}}
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_duplicated_bewteen_split_and_partial(
+ // expected-error at +1 {{mesh axis duplicated}}
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_negtive_in_split_part(
+ // expected-error at +1 {{mesh axis is expected to be non-negative}}
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_negtive_in_partial(
+ // expected-error at +1 {{mesh axis is expected to be non-negative}}
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
new file mode 100644
index 000000000000000..ee5f8f67792b928
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK: mesh.cluster @mesh0
+mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+
+// CHECK: mesh.cluster @mesh1
+mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+
+// CHECK: mesh.cluster @mesh2
+mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+
+// CHECK: mesh.cluster @mesh3
+mesh.cluster @mesh3(rank = 2)
+
+// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
+func.func @mesh_shard_encoding_fully_replicated(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_encoding_1st_dim
+func.func @mesh_shard_encoding_1st_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}0]]>>
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_encoding_2nd_dim
+func.func @mesh_shard_encoding_2nd_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh1, {{\[\[}}], [0]]>>
+ %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) ->
+ tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> {
+ return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim
+func.func @mesh_shard_encoding_1st_and_3rd_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32, #mesh.shard<@mesh3, {{\[\[}}0], [], [1]]>>
+ %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) ->
+ tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> {
+ return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_1st_dim
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_2nd_dim
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh1, {{\[\[}}], [0]]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh1, [[], [0]]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
+func.func @mesh_shard_op_1st_and_3rd_dim(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
+ %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8x16xf32>
+ %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32>
+ return %0 : tensor<4x8x16xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_max
+func.func @mesh_shard_op_partial_max(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+ %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = max[1]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = max[1]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_min
+func.func @mesh_shard_op_partial_min(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+ %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = min[1]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = min[1]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_generic
+func.func @mesh_shard_op_partial_generic(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+ %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = generic[1]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = generic[1]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_sum
+func.func @mesh_shard_op_partial_sum(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+ %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
+func.func @mesh_shard_op_partial_sum_multi_axes(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+ %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+ // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1, 2]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1, 2]> : tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_two_users
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
+ (tensor<4x8xf32>, tensor<4x8xf32>) {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+ // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}1]]> annotate_for_users : tensor<4x8xf32>
+ %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+ // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}2]]> annotate_for_users : tensor<4x8xf32>
+ %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
+ return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
+}
More information about the Mlir-commits
mailing list