[Mlir-commits] [mlir] [Mesh] initialize mesh dialect (PR #68007)
Chase Roberts
llvmlistbot at llvm.org
Wed Oct 4 22:50:48 PDT 2023
================
@@ -0,0 +1,128 @@
+//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
+
+//===----------------------------------------------------------------------===//
+// Mesh Dialect
+//===----------------------------------------------------------------------===//
+
+// Top-level definition of the `mesh` dialect; its C++ namespace is
+// ::mlir::mesh.
+def Mesh_Dialect : Dialect {
+ let cppNamespace = "::mlir::mesh";
+ let name = "mesh";
+
+ // Opt in to the constant materializer hook and to the default
+ // attribute printer/parser.
+ let hasConstantMaterializer = 1;
+ let useDefaultAttributePrinterParser = 1;
+
+ // The arith dialect is a dependency because materializeConstant() needs it.
+ let dependentDialects = ["arith::ArithDialect"];
+
+ let description = [{
+ The `mesh` dialect contains a set of attributes, operations, interfaces that
+ are useful for representing sharding and communication on device mesh
+ cluster.
+ }];
+}
+//===----------------------------------------------------------------------===//
+// Mesh Enums.
+//===----------------------------------------------------------------------===//
+
+// 32-bit enum listing the partial types of a distributed tensor, placed in
+// ::mlir::mesh. genSpecializedAttr is turned off for this enum.
+def Mesh_Partial : I32EnumAttr<
+ "Partial", "partial type of a distributed tensor",
+ [
+ I32EnumAttrCase<"Sum", 1, "sum">,
+ I32EnumAttrCase<"Max", 2, "max">,
+ I32EnumAttrCase<"Min", 3, "min">,
+ I32EnumAttrCase<"Generic", 100, "generic">
+ ]> {
+ let cppNamespace = "::mlir::mesh";
+ let genSpecializedAttr = 0;
+}
+
+//===----------------------------------------------------------------------===//
+// Mesh Attribute
+//===----------------------------------------------------------------------===//
+
+def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
+ let mnemonic = "shard";
+
+ let parameters = (ins
+ AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+ ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
+ OptionalArrayRefParameter<"int8_t">:$partial_axes,
+ OptionalParameter<"::mlir::mesh::Partial">:$partial_type
+ );
+
+ let summary = "Attribute that extends tensor type to distributed tensor type.";
+
+ let description = [{
+ The MeshSharding attribute could be used in the encoding of a
+ `RankedTensorType` or the mesh.shard op. It contains three sub-attributes:
+
+ 1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+ cluster where the distributed tensor is placed. The symbol must resolve to a
+ `mesh.cluster` operation.
+
+ 2. `split_axes`: an array composed of int8_t sub-arrays. The outer array's
+ maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+ its value is [x, y], it indicates that the tensor's i-th dimension is split
+ along the x and y axes of the device mesh.
+
+ 3. `partial_axes`: if not empty, this sifnifies that the tensor is partial
----------------
chaserileyroberts wrote:
```suggestion
3. `partial_axes`: if not empty, this signifies that the tensor is partial
```
https://github.com/llvm/llvm-project/pull/68007
More information about the Mlir-commits
mailing list