[Mlir-commits] [mlir] [Mesh] initialize mesh dialect (PR #68007)

Chengji Yao llvmlistbot at llvm.org
Mon Oct 2 15:46:14 PDT 2023


================
@@ -0,0 +1,91 @@
+//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Mesh Dialect
+//===----------------------------------------------------------------------===//
+
+def Mesh_Dialect : Dialect {
+  let name = "mesh";
+  let cppNamespace = "::mlir::mesh";
+
+  let description = [{
+    The `mesh` dialect contains a set of attributes, operations, interfaces that
+    are useful for representing sharding and communication on device mesh
+    cluster.
+  }];
+
+  let dependentDialects = [
+    "arith::ArithDialect"
+  ];
+
+  let useDefaultAttributePrinterParser = 1;
+  let hasConstantMaterializer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Mesh Attribute
+//===----------------------------------------------------------------------===//
+
+def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
+  let mnemonic = "shard";
+
+  let parameters = (ins
+    OptionalParameter<"::mlir::SymbolRefAttr">:$cluster,
----------------
yaochengji wrote:

I just remembered another reason to make `cluster` optional in the mesh sharding attribute: function reuse, where the cluster info actually lives on the `func.call` op.

```
func.func @foo() ...

func.func @main() {
    ...
    %1 = func.call @foo(%0) { mesh_cluster = @mesh0 } ...
    %2 = func.call @foo(%1) { mesh_cluster = @mesh1 } ...
    %3 = func.call @foo(%2) { mesh_cluster = @mesh2 } ...
    ...
}
```

This should be useful for pipeline parallelism: here we only need to define one function instead of three. The additional logic can be summarized as follows: when the enclosing FunctionInterface operation has no associated cluster information, all corresponding `func.call` operations are examined for cluster data. It may be prudent to introduce some verification logic: if a cluster is specified on one `func.call` operation, then all corresponding `func.call` operations should carry cluster information, and all of those clusters should have exactly the same rank and `dim_sizes`.
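
To make the proposed verification concrete, here is a rough sketch in plain Python rather than MLIR C++. It only models the rule described above; `Cluster`, `Call`, and `verify_call_clusters` are hypothetical names invented for illustration and are not part of the patch or of any MLIR API.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class Cluster:
    """Hypothetical stand-in for a mesh cluster symbol such as @mesh0."""
    name: str
    dim_sizes: Tuple[int, ...]

    @property
    def rank(self) -> int:
        return len(self.dim_sizes)


@dataclass(frozen=True)
class Call:
    """Hypothetical stand-in for a func.call op; mesh_cluster may be absent."""
    callee: str
    mesh_cluster: Optional[str] = None


def verify_call_clusters(
    callee: str, calls: List[Call], clusters: Dict[str, Cluster]
) -> List[str]:
    """Return error messages; an empty list means the call sites are consistent."""
    sites = [c for c in calls if c.callee == callee]
    annotated = [c for c in sites if c.mesh_cluster is not None]
    if not annotated:
        # No call site carries cluster info, so there is nothing to check.
        return []

    errors = []
    # Rule 1: if one call names a cluster, every call to the callee must.
    if len(annotated) != len(sites):
        errors.append(f"some calls to @{callee} lack mesh_cluster")

    # Each referenced cluster must resolve to a known symbol.
    for c in annotated:
        if c.mesh_cluster not in clusters:
            errors.append(f"unknown cluster @{c.mesh_cluster}")

    # Rule 2: all clusters must agree on dim_sizes (equal dim_sizes
    # implies equal rank, so one comparison covers both).
    shapes = {
        clusters[c.mesh_cluster].dim_sizes
        for c in annotated
        if c.mesh_cluster in clusters
    }
    if len(shapes) > 1:
        errors.append(f"clusters used by @{callee} differ in dim_sizes")
    return errors
```

For the example above, three calls to `@foo` annotated with `@mesh0`, `@mesh1`, and `@mesh2` would pass only if all three meshes share the same `dim_sizes`; dropping the annotation from any one call would be reported as an error.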

https://github.com/llvm/llvm-project/pull/68007

