[Mlir-commits] [mlir] [Mesh] initialize mesh dialect (PR #68007)

Chengji Yao llvmlistbot at llvm.org
Tue Oct 3 15:05:25 PDT 2023


https://github.com/yaochengji updated https://github.com/llvm/llvm-project/pull/68007

>From 655ca9e9775fdded3c0e479a6bf6d3b82f24c5d2 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmial.com>
Date: Sat, 30 Sep 2023 00:24:38 +0000
Subject: [PATCH 1/4] [Mesh] initialize mesh dialect

---
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 +
 mlir/include/mlir/Dialect/Mesh/CMakeLists.txt |   1 +
 .../mlir/Dialect/Mesh/IR/CMakeLists.txt       |   8 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  91 ++++++++++++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   |  35 +++++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 132 ++++++++++++++++++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/lib/Dialect/CMakeLists.txt               |   1 +
 mlir/lib/Dialect/Mesh/CMakeLists.txt          |  15 ++
 mlir/lib/Dialect/Mesh/MeshOps.cpp             | 122 ++++++++++++++++
 mlir/test/Dialect/Mesh/invalid.mlir           |  65 +++++++++
 mlir/test/Dialect/Mesh/ops.mlir               |  69 +++++++++
 12 files changed, 542 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
 create mode 100644 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
 create mode 100644 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
 create mode 100644 mlir/lib/Dialect/Mesh/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Mesh/MeshOps.cpp
 create mode 100644 mlir/test/Dialect/Mesh/invalid.mlir
 create mode 100644 mlir/test/Dialect/Mesh/ops.mlir

diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 9082d9633339c9f..1c4569ecfa58485 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(Mesh)
 add_subdirectory(MLProgram)
 add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..f33061b2d87cffc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
new file mode 100644
index 000000000000000..f1d15891d029e1a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_dialect(MeshOps mesh)
+add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+add_mlir_doc(MeshOps MeshAttributes Dialects/ -gen-attrdef-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
new file mode 100644
index 000000000000000..05b5eb75c66d3f2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -0,0 +1,91 @@
+//===- MeshBae.td - Mesh Dialect Dialect -------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHBASE_TD
+#define MLIR_DIALECT_MESH_IR_MESHBASE_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Mesh Dialect
+//===----------------------------------------------------------------------===//
+
+def Mesh_Dialect : Dialect {
+  let name = "mesh";
+  let cppNamespace = "::mlir::mesh";
+
+  let description = [{
+    The `mesh` dialect contains a set of attributes, operations, and interfaces
+    that are useful for representing sharding and communication on a device
+    mesh cluster.
+  }];
+
+  let dependentDialects = [
+    "arith::ArithDialect"
+  ];
+
+  let useDefaultAttributePrinterParser = 1;
+  let hasConstantMaterializer = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Mesh Attribute
+//===----------------------------------------------------------------------===//
+
+def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
+  let mnemonic = "shard";
+
+  let parameters = (ins
+    OptionalParameter<"::mlir::SymbolRefAttr">:$cluster,
+    ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$axes
+  );
+
+  let summary = "Attribute that extends tensor type to distributed tensor type.";
+
+  let description = [{
+    The mesh.shard attribute contains two attributes:
+    1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+    cluster where the distributed tensor is placed.
+
+    2. `axes`: is an array composed of int64_t sub-arrays. The outer array's
+    maximum size is the `rank` of the related tensor plus one. For the i-th
+    sub-array, if its value is [x, y]:
+    - When i < `rank`, it indicates that the tensor's i-th dimension is sharded
+    along the x and y axes of the device mesh.
+    - When i == `rank`, it signifies that the tensor represents a partial sum
+    along the x and y axes. More partial types could be introduced if needed, 
+    e.g. partial-max, partial-min and even partial-generic.
+
+    Example:
+
+    ```
+    mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+
+    // The tensor is fully replicated on @mesh0.
+    // Currently, there must be at least one sub-array present in axes, even
+    // if it's empty. Otherwise, a parsing error will occur.
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
+
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // it is also a partial-sum along axis 1.
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>>
+
+    // Could also be used in the attribute of mesh.shard op
+    %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+    ```
+  }];
+  let assemblyFormat = "`<` ($cluster^ `,`)? `[` $axes `]` `>`";
+  let genVerifyDecl = 1;
+}
+
+#endif // MLIR_DIALECT_MESH_IR_MESHBASE_TD
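
For readers skimming the patch, the pieces above compose as in the following
minimal sketch, assembled from the examples in the description (the cluster
shape and tensor types are illustrative):

```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

// Dimension 0 is sharded along mesh axis 0, dimension 1 is replicated, and
// the trailing sub-array (index == rank) marks a partial sum along mesh
// axis 1.
func.func @partial_sum_example(
    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>>) ->
            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>> {
  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>>
}
```
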
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
new file mode 100644
index 000000000000000..69159fc0d8cca89
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -0,0 +1,35 @@
+//===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
+#define MLIR_DIALECT_MESH_IR_MESHOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace mlir {
+
+namespace mesh {
+
+constexpr StringRef getMeshClusterAttrName() { return "mesh_cluster"; }
+
+} // namespace mesh
+
+} // namespace mlir
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
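
A note on `getMeshClusterAttrName()`: the returned string is the attribute key
used to attach a default cluster to an op, as in this sketch (the module is
hypothetical; @mesh0 would have to be defined by a mesh.cluster op):

```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

// The "mesh_cluster" key below is the name returned by
// getMeshClusterAttrName().
func.func @foo() attributes { mesh_cluster = @mesh0 } {
  return
}
```
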
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
new file mode 100644
index 000000000000000..bbd4891a2947517
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -0,0 +1,132 @@
+//===-- MeshOps.td - Mesh dialect operation definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHOPS_TD
+#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
+
+include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// Mesh Dialect operations.
+//===----------------------------------------------------------------------===//
+
+class Mesh_Op<string mnemonic, list<Trait> traits = []> :
+    Op<Mesh_Dialect, mnemonic, traits> {
+}
+
+def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
+  let summary = "representing a mesh cluster";
+  let description = [{
+    The mesh.cluster operation is a symbol operation that identifies a specific
+    mesh cluster. The operation has three attributes:
+
+    1. `sym_name`: This attribute uniquely identifies the name of the mesh
+    cluster. This name serves as a symbolic reference to the cluster throughout
+    the MLIR module, allowing for consistent referencing and easier debugging.
+
+    2. `rank`: This attribute specifies the number of axes of the cluster. The
+    rank indicates the dimensionality of the mesh cluster and can be used to
+    determine the layout and the addressing space of the computation distributed
+    across the mesh.
+
+    3. `dim_sizes`: This attribute represents the device assignment along the
+    axes of the cluster. Each integer in the array corresponds to the number of
+    devices along a specific axis. If an integer value is <= 0, it implies that
+    the number of devices along that axis is unknown. This flexibility allows
+    for dynamic device assignment or configurations where the exact number of
+    devices might not be determined during compile time.
+
+    Example:
+    ```
+    // A device mesh cluster with 3 axes, the totol device number is 4 * 8 * 12
+    // The dimension sizes are 4, 8, 12 
+    mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
+
+    // A device mesh cluster with 2 axes, the totol device number is unknown
+    // The first dimension size is 4 and the second is unknown
+    mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+
+    // A device mesh cluster with 2 axes, the totol device number is unknown
+    // The first dimension size is unknown and the second is 4
+    mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+
+    // A device mesh cluster with 2 axes, the number of devices along both axes
+    // is unknown
+    mesh.cluster @mesh3(rank = 2)
+
+    // A func op's default mesh cluster is @mesh0
+    func.func @foo() -> () attributes { mesh_cluster = @mesh0 } {
+      ...
+    }
+
+    // Used in the mesh sharding attribute to extend the standard tensor to
+    // distributed
+    tensor<4x8xf32, #mesh.shard<[[0]] @mesh0>
+    ```
+  }];
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    I64Attr:$rank,
+    DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
+  );
+  let assemblyFormat = "$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)` attr-dict";
+  let hasVerifier = 1;
+}
+
+def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
+  let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
+  let description = [{
+    The mesh.shard operation is designed to specify and guide the sharding
+    behavior of a tensor value across a mesh topology. This operation has one
+    operand and two attributes:
+
+    1. `input`: This operand represents the tensor value that needs to be
+    annotated for sharding.
+
+    2. `shard`: This attribute is of type `MeshSharding`, which is the core data
+    structure to represent a distributed tensor in a mesh cluster.
+
+    3. `as_result`: A boolean attribute addressing the scenario when a tensor's
+    sharding annotation differs based on its context of use (either as a result
+    or an operand). If true, the sharding applies to the operation that defines
+    the tensor value. If false, the sharding pertains to specific users of the
+    tensor value, indicating how it should be considered when used as an operand
+    in subsequent operations.
+
+    Example:
+    ```
+    // The first mesh.shard op applies to %arg0, the second mesh.shard op
+    // applies to the operand of op0, and the third mesh.shard op applies to
+    // the operand of op1
+    func.func @two_users(%arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+      %1 = mesh.shard %0 to <@mesh0, [[1]]> {as_result = false} : tensor<4x8xf32>
+      %2 = mesh.shard %0 to <@mesh0, [[2]]> {as_result = false} : tensor<4x8xf32>
+      "op0"(%1) : ...
+      "op1"(%2) : ...
+      ...
+    }
+    ```
+  }];
+  let arguments = (ins
+    Builtin_RankedTensor:$src,
+    MeshSharding:$shard,
+    DefaultValuedAttr<BoolAttr, "true">:$as_result
+  );
+  let results = (outs
+    Builtin_RankedTensor:$result
+  );
+  let assemblyFormat = "$src `to` $shard attr-dict `:` type($result)";
+  let hasVerifier = 1;
+}
+
+#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
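
To illustrate how the two ops interact, here is a sketch that shards onto a
cluster with one statically unknown axis size, per the `dim_sizes` rules above
(the function name and shapes are illustrative):

```
// Two mesh axes; axis 0 has 4 devices, the size of axis 1 is unknown.
mesh.cluster @dyn_mesh(rank = 2, dim_sizes = [4])

func.func @annotate(%arg0 : tensor<16x8xf32>) -> tensor<16x8xf32> {
  // Shard dimension 0 along mesh axis 0 and dimension 1 along mesh axis 1.
  %0 = mesh.shard %arg0 to <@dyn_mesh, [[0], [1]]> : tensor<16x8xf32>
  return %0 : tensor<16x8xf32>
}
```
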
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 5ec36a7f289e586..cbdc29020367a01 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,6 +55,7 @@
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -117,6 +118,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   LLVM::LLVMDialect,
                   math::MathDialect,
                   memref::MemRefDialect,
+                  mesh::MeshDialect,
                   ml_program::MLProgramDialect,
                   nvgpu::NVGPUDialect,
                   NVVM::NVVMDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index d9f6b0fb7e63c2a..68776a695cac4d4 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -19,6 +19,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(Mesh)
 add_subdirectory(MLProgram)
 add_subdirectory(NVGPU)
 add_subdirectory(OpenACC)
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..fc3f7c96f05c77f
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRMeshDialect
+  MeshOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRMeshOpsAttrIncGen
+  MLIRMeshOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRIR
+  MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/MeshOps.cpp b/mlir/lib/Dialect/Mesh/MeshOps.cpp
new file mode 100644
index 000000000000000..fd9f332e204de45
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/MeshOps.cpp
@@ -0,0 +1,122 @@
+//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "mesh-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Mesh dialect
+//===----------------------------------------------------------------------===//
+
+void MeshDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+      >();
+}
+
+Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
+                                            Type type, Location loc) {
+  return arith::ConstantOp::materialize(builder, value, type, loc);
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.cluster op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ClusterOp::verify() {
+  ArrayRef<int64_t> dimSizes = getDimSizes();
+  size_t rank = getRank();
+
+  if (rank == 0)
+    return emitOpError("rank of cluster is expected to be a positive integer");
+
+  if (dimSizes.size() > rank)
+    return emitOpError(
+        "rank of dim_sizes is not expected to be larger than rank of cluster");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard op
+//===----------------------------------------------------------------------===//
+
+LogicalResult ShardOp::verify() {
+  bool asResult = getAsResult();
+  if (asResult) {
+    Value src = getSrc();
+    Operation *defOp = src.getDefiningOp();
+    if (llvm::isa_and_nonnull<ShardOp>(defOp))
+      return emitOpError("two mesh.shard ops with as_result = true are not "
+                         "expected to be stacked together");
+
+    unsigned numShard = llvm::count_if(src.getUsers(), [](Operation *user) {
+      return llvm::isa<ShardOp>(user);
+    });
+    if (numShard > 1)
+      return emitOpError(
+          "when than one mesh.shard ops operate on the same tensor, all of "
+          "their as_result attributes are expected to be false");
+
+  } else {
+    ShardOp defShardOp = getSrc().getDefiningOp<ShardOp>();
+    if (defShardOp && !defShardOp.getAsResult())
+      return emitOpError("two mesh.shard ops with as_result = false are not "
+                         "expected to be stacked together");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                         SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> axes) {
+  // TODO: At present cluster symbol ref is not verified. This is due to the
+  // difficulty in fetching the corresponding symbol op based on an attribute.
+
+  DenseSet<int64_t> visitedAxes;
+  for (DenseI64ArrayAttr subAxes : axes) {
+    ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
+    for (int64_t axis : subAxesArray) {
+      if (!visitedAxes.insert(axis).second)
+        return emitError() << "mesh axis duplicated";
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
new file mode 100644
index 000000000000000..8898ab000e6b6ed
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+// expected-error at +1 {{rank of cluster is expected to be a positive integer}}
+mesh.cluster @mesh0(rank = 0)
+
+// -----
+
+// expected-error at +1 {{rank of dim_sizes is not expected to be larger than rank of cluster}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4])
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_shard_op_stacked_true_true(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh0, [[], [0]]> : tensor<4x8xf32>
+  // expected-error at +1 {{two mesh.shard ops with as_result = true are not expected to be stacked together}}
+  %1 = mesh.shard %0 to <@mesh0, [[], [0, 1]]> : tensor<4x8xf32>
+  return %1 : tensor<4x8xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_shard_op_stacked_false_false(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh0, [[], [0]]> {as_result = false} : tensor<4x8xf32>
+  // expected-error at +1 {{two mesh.shard ops with as_result = false are not expected to be stacked together}}
+  %1 = mesh.shard %0 to <@mesh0, [[], [0, 1]]> {as_result = false} : tensor<4x8xf32>
+  return %1 : tensor<4x8xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_shard_ops_on_same_tensor(%arg0 : tensor<4x8xf32>) -> 
+                                        (tensor<4x8xf32>, tensor<4x8xf32>) {
+  // expected-error at +1 {{when more than one mesh.shard op operates on the same tensor, all of their as_result attributes are expected to be false}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[], [0]]> : tensor<4x8xf32>
+  %1 = mesh.shard %arg0 to <@mesh0, [[], [0, 1]]> : tensor<4x8xf32>
+  return %0, %1 : tensor<4x8xf32>, tensor<4x8xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_duplicated(
+    // expected-error at +1 {{mesh axis duplicated}}
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_duplicated_2(
+    // expected-error at +1 {{mesh axis duplicated}}
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
new file mode 100644
index 000000000000000..c87f07257d66096
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK: mesh.cluster @mesh0
+mesh.cluster @mesh0(rank = 3, dim_sizes = [2, 2, 4])
+
+// CHECK: mesh.cluster @mesh1
+mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
+
+// CHECK: mesh.cluster @mesh2
+mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
+
+// CHECK: mesh.cluster @mesh3
+mesh.cluster @mesh3(rank = 2)
+
+// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
+func.func @mesh_shard_encoding_fully_replicated(
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_encoding_1st_dim
+func.func @mesh_shard_encoding_1st_dim(
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_encoding_2nd_dim
+func.func @mesh_shard_encoding_2nd_dim(
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) -> 
+    tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim
+func.func @mesh_shard_encoding_1st_and_3rd_dim(
+    %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) -> 
+            tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> {
+  return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_1st_dim
+func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_2nd_dim
+func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %0 = mesh.shard %arg0 to <@mesh1, [[], [0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
+func.func @mesh_shard_op_1st_and_3rd_dim(
+    %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32>
+  return %0 : tensor<4x8x16xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_two_users
+func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> 
+                                  (tensor<4x8xf32>, tensor<4x8xf32>) {
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+  %1 = mesh.shard %0 to <@mesh0, [[1]]> {as_result = false} : tensor<4x8xf32>
+  %2 = mesh.shard %0 to <@mesh0, [[2]]> {as_result = false} : tensor<4x8xf32>
+  return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
+}

>From ded65727a402f7a148b1be9aa0801c25a879f572 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmial.com>
Date: Mon, 2 Oct 2023 17:55:54 +0000
Subject: [PATCH 2/4] small fix

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  2 +-
 mlir/lib/Dialect/Mesh/CMakeLists.txt          | 16 +---------------
 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt       | 15 +++++++++++++++
 mlir/lib/Dialect/Mesh/{ => IR}/MeshOps.cpp    |  0
 4 files changed, 17 insertions(+), 16 deletions(-)
 create mode 100644 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
 rename mlir/lib/Dialect/Mesh/{ => IR}/MeshOps.cpp (100%)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 05b5eb75c66d3f2..242d526e42c02a5 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -1,4 +1,4 @@
-//===- MeshBae.td - Mesh Dialect Dialect -------------------*- tablegen -*-===//
+//===- MeshBase.td - Mesh Dialect --------------------------*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index fc3f7c96f05c77f..f33061b2d87cffc 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,15 +1 @@
-add_mlir_dialect_library(MLIRMeshDialect
-  MeshOps.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
-
-  DEPENDS
-  MLIRMeshOpsAttrIncGen
-  MLIRMeshOpsIncGen
-
-  LINK_LIBS PUBLIC
-  MLIRArithDialect
-  MLIRIR
-  MLIRSupport
-)
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
new file mode 100644
index 000000000000000..fc3f7c96f05c77f
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRMeshDialect
+  MeshOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+  DEPENDS
+  MLIRMeshOpsAttrIncGen
+  MLIRMeshOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRIR
+  MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
similarity index 100%
rename from mlir/lib/Dialect/Mesh/MeshOps.cpp
rename to mlir/lib/Dialect/Mesh/IR/MeshOps.cpp

>From ee0179330df02f33310c441d3c40dae615bcc322 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmial.com>
Date: Tue, 3 Oct 2023 15:23:31 +0000
Subject: [PATCH 3/4] fix comments

---
 .../mlir/Dialect/Mesh/IR/CMakeLists.txt       |  5 ++
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 54 +++++++++---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   | 12 +--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  | 88 ++++++++++++++-----
 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt       |  1 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 49 ++++-------
 mlir/test/Dialect/Mesh/invalid.mlir           | 44 +++-------
 mlir/test/Dialect/Mesh/ops.mlir               | 54 +++++++++++-
 8 files changed, 196 insertions(+), 111 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index f1d15891d029e1a..cfc948e305638fa 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -6,3 +6,8 @@ mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
 add_mlir_doc(MeshOps MeshAttributes Dialects/ -gen-attrdef-doc)
+
+set(LLVM_TARGET_DEFINITIONS MeshBase.td)
+mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 242d526e42c02a5..4a5feacd110a80e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -12,6 +12,7 @@
 include "mlir/IR/OpBase.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/EnumAttr.td"
 
 //===----------------------------------------------------------------------===//
 // Mesh Dialect
@@ -28,12 +29,25 @@ def Mesh_Dialect : Dialect {
   }];
 
   let dependentDialects = [
-    "arith::ArithDialect"
+    "arith::ArithDialect" // For materializeConstant()
   ];
 
   let useDefaultAttributePrinterParser = 1;
   let hasConstantMaterializer = 1;
 }
+//===----------------------------------------------------------------------===//
+// Mesh Enums.
+//===----------------------------------------------------------------------===//
+
+def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
+  I32EnumAttrCase<"Sum", 0, "partial_sum">,
+  I32EnumAttrCase<"Max", 1, "partial_max">,
+  I32EnumAttrCase<"Min", 2, "partial_min">,
+  I32EnumAttrCase<"Generic", 100, "partial_generic">
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mesh";
+}
 
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
@@ -43,25 +57,39 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   let mnemonic = "shard";
 
   let parameters = (ins
-    OptionalParameter<"::mlir::SymbolRefAttr">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$axes
+    AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$axes,
+    DefaultValuedParameter<"::mlir::mesh::Partial", 
+                           "::mlir::mesh::Partial::Sum">:$partial_type
   );
 
   let summary = "Attribute that extends tensor type to distributed tensor type.";
 
   let description = [{
-    The mesh.shard attribute contains two attributes:
+    The MeshSharding attribute can be used in the encoding of a
+    `RankedTensorType` or the mesh.shard op. It contains these sub-attributes:
     1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
-    cluster where the distributed tensor is placed.
+    cluster where the distributed tensor is placed. The symbol must resolve to a
+    `mesh.cluster` operation.
 
     2. `axes`: is an array composed of int64_t sub-arrays. The outer array's
     maximum size is the `rank` of the related tensor plus one. For the i-th
     sub-array, if its value is [x, y]:
     - When i < `rank`, it indicates that the tensor's i-th dimension is sharded
     along the x and y axes of the device mesh.
-    - When i == `rank`, it signifies that the tensor represents a partial sum
-    along the x and y axes. More partial types could be introduced if needed, 
-    e.g. partial-max, partial-min and even partial-generic.
+    - When i == `rank`, it signifies that the tensor is a partial one along the x
+    and y axes. An all-reduce should be applied to get the whole tensor, and the
+    reduction type is specified by `partial_type`.
+
+    3. `partial_type`: indicates the reduction type of the possible all-reduce
+    op. It has 4 possible values:
+    - `partial_sum`: denotes it's an all-reduce-sum
+    - `partial_max`: denotes it's an all-reduce-max
+    - `partial_min`: denotes it's an all-reduce-min
+    - `partial_generic`: denotes that the all-reduce type is complex and cannot
+    be represented merely by a simple sum, max, or min. The exact reduction
+    computation may be derived from the semantics of the corresponding operation
+    or from the reduction computation IR.
 
     Example:
 
@@ -77,14 +105,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
-    // it is also a partial-sum along axis 1.
+    // it is also a partial_sum along mesh axis 1.
     tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>>
 
-    // Could also be used in the attribute of mesh.shard op
+    // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
+    // it is also a partial_max along mesh axis 1.
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]], partial_max>>
+
+    // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
     ```
   }];
-  let assemblyFormat = "`<` ($cluster^ `,`)? `[` $axes `]` `>`";
+  let assemblyFormat = "`<` $cluster `,` `[` $axes `]` (`,` $partial_type^)?`>`";
   let genVerifyDecl = 1;
 }
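
With the new Partial enum and the revised assembly format, the partial type
can now be spelled explicitly. A sketch (illustrative shapes; partial_sum
remains the default when the trailing sub-array is present but no type is
given):

```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @partial_min_example(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // The trailing sub-array is reduced with min instead of the default sum.
  %0 = mesh.shard %arg0 to <@mesh0, [[0], [], [1]], partial_min> : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
```
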
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 69159fc0d8cca89..9dfeca84d012165 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -14,18 +14,10 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-namespace mlir {
-
-namespace mesh {
-
-constexpr StringRef getMeshClusterAttrName() { return "mesh_cluster"; }
-
-} // namespace mesh
-
-} // namespace mlir
-
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
 
+#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index bbd4891a2947517..3acb7cabdc634f6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -40,9 +40,9 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
 
     3. `dim_sizes`: This attribute represents the device assignment along the
     axes of the cluster. Each integer in the array corresponds to the number of
-    devices along a specific axis. If an integer value is <= 0, it implies that
-    the number of devices along that axis is unknown. This flexibility allows
-    for dynamic device assignment or configurations where the exact number of
+    devices along a specific axis. If an integer value is 0, it implies that the
+    number of devices along that axis is unknown. This flexibility allows for
+    dynamic device assignment or configurations where the exact number of
     devices might not be determined during compile time.
 
     Example:
@@ -63,22 +63,20 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
     // is unknown
     mesh.cluster @mesh3(rank = 2)
 
-    // A func op's default mesh cluster is @mesh0
-    func.func @foo() -> () attributes { mesh_cluster = @mesh0 } {
-      ...
-    }
-
     // Used in the mesh sharding attribute to extend the standard tensor to
     // distributed
-    tensor<4x8xf32, #mesh.shard<[[0]] @mesh0>
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
     ```
   }];
   let arguments = (ins
     SymbolNameAttr:$sym_name,
-    I64Attr:$rank,
+    I8Attr:$rank,
     DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
   );
-  let assemblyFormat = "$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)` attr-dict";
+  let assemblyFormat = [{
+    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
+      attr-dict
+  }];
   let hasVerifier = 1;
 }
 
@@ -95,38 +93,82 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
     2. `shard`: This attribute is of type `MeshSharding`, which is the core data
     structure to represent a distributed tensor in a mesh cluster.
 
-    3. `as_result`: A boolean attribute addressing the scenario when a tensor's
-    sharding annotation differs based on its context of use (either as a result
-    or an operand). If true, the sharding applies to the operation that defines
-    the tensor value. If false, the sharding pertains to specific users of the
-    tensor value, indicating how it should be considered when used as an operand
-    in subsequent operations.
+    3. `annotate_for_users`: A unit attribute addressing the scenario when a
+    tensor's sharding annotation differs based on its context of use (either as
+    a result or an operand). If specified, the sharding pertains to specific
+    users of the tensor value, indicating how it should be considered when used
+    as an operand in subsequent operations. If not, the sharding applies to the
+    operation that defines the tensor value.
 
     Example:
     ```
+    func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+      ...
+    }
+
+    func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+      ...
+    }
+
     // The first mesh.shard op applies to %arg0, the second mesh.shard op
     // applies to the operand of op0, and the third mesh.shard op applies to
     // the operand of op1
-    func.func @two_users(%arg0 : tensor<4x8xf32>) -> () {
+    func.func @both_result_and_multi_operands_annotated(
+        %arg0 : tensor<4x8xf32>) -> () {
       %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-      %1 = mesh.shard %0 to <@mesh0, [[1]]> {as_result = false} : tensor<4x8xf32>
-      %2 = mesh.shard %0 to <@mesh0, [[2]]> {as_result = false} : tensor<4x8xf32>
+      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+      %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
       "op0"(%1) : ...
       "op1"(%2) : ...
       ...
     }
     ```
+
+    The following usages are undefined:
+    ```
+    func.func @annotate_on_same_result_with_different_sharding(
+        %arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+      ...
+    }
+
+    func.func @annotate_on_same_result_same_value_with_different_sharding(
+        %arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
+      %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+      ...
+    }
+
+    func.func @annotate_on_same_operand_with_different_sharding(
+        %arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+      %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+      ...
+    }
+
+    func.func @result_annotated_after_operand(
+        %arg0 : tensor<4x8xf32>) -> () {
+      %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32>
+      %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32>
+      ...
+    }
+    ```
   }];
   let arguments = (ins
     Builtin_RankedTensor:$src,
     MeshSharding:$shard,
-    DefaultValuedAttr<BoolAttr, "true">:$as_result
+    UnitAttr:$annotate_for_users
   );
   let results = (outs
     Builtin_RankedTensor:$result
   );
-  let assemblyFormat = "$src `to` $shard attr-dict `:` type($result)";
-  let hasVerifier = 1;
+  let assemblyFormat = [{
+    $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
+      type($result)
+  }];
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
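
The unit-attribute spelling reads directly in the IR. A sketch of annotating
both sides of a single def-use edge (a downstream partitioning pass, not part
of this patch, would be responsible for materializing any resharding between
the two annotations):

```
mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])

func.func @edge(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // Sharding of the value as produced ...
  %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
  // ... and the (possibly different) sharding expected by its user.
  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
  return %1 : tensor<4x8xf32>
}
```
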
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index fc3f7c96f05c77f..700e6e21f36b677 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRMeshDialect
 
   DEPENDS
   MLIRMeshOpsAttrIncGen
+  MLIRMeshOpsEnumsIncGen
   MLIRMeshOpsIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fd9f332e204de45..5781bc8130aef8a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "mesh-ops"
@@ -46,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 LogicalResult ClusterOp::verify() {
   ArrayRef<int64_t> dimSizes = getDimSizes();
-  size_t rank = getRank();
+  uint8_t rank = getRank();
 
   if (rank == 0)
     return emitOpError("rank of cluster is expected to be a positive integer");
@@ -55,35 +56,10 @@ LogicalResult ClusterOp::verify() {
     return emitOpError(
         "rank of dim_sizes is not expected to be larger than rank of cluster");
 
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// mesh.shard op
-//===----------------------------------------------------------------------===//
-
-LogicalResult ShardOp::verify() {
-  bool asResult = getAsResult();
-  if (asResult) {
-    Value src = getSrc();
-    Operation *defOp = src.getDefiningOp();
-    if (llvm::isa_and_nonnull<ShardOp>(defOp))
-      return emitOpError("two mesh.shard ops with as_result = true are not "
-                         "expected to be stacked together");
-
-    unsigned numShard = llvm::count_if(src.getUsers(), [](Operation *user) {
-      return llvm::isa<ShardOp>(user);
-    });
-    if (numShard > 1)
+  for (int64_t dimSize : dimSizes) {
+    if (dimSize < 0)
       return emitOpError(
-          "when than one mesh.shard ops operate on the same tensor, all of "
-          "their as_result attributes are expected to be false");
-
-  } else {
-    ShardOp defShardOp = getSrc().getDefiningOp<ShardOp>();
-    if (defShardOp && !defShardOp.getAsResult())
-      return emitOpError("two mesh.shard ops with as_result = false are not "
-                         "expected to be stacked together");
+          "dimension size of a mesh cluster is expected to be non-negative");
   }
 
   return success();
@@ -95,14 +71,17 @@ LogicalResult ShardOp::verify() {
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> axes) {
+                         SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> axes,
+                         Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
-  DenseSet<int64_t> visitedAxes;
-  for (DenseI64ArrayAttr subAxes : axes) {
-    ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
-    for (int64_t axis : subAxesArray) {
+  llvm::SmallSet<int8_t, 4> visitedAxes;
+  for (DenseI8ArrayAttr subAxes : axes) {
+    ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+    for (int8_t axis : subAxesArray) {
+      if (axis < 0)
+        return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
         return emitError() << "mesh axis duplicated";
     }
@@ -120,3 +99,5 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 8898ab000e6b6ed..39a4645176459f9 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -10,37 +10,8 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 3, 4])
 
 // -----
 
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
-
-func.func @mesh_shard_op_stacked_true_true(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = mesh.shard %arg0 to <@mesh0, [[], [0]]> : tensor<4x8xf32>
-  // expected-error at +1 {{two mesh.shard ops with as_result = true are not expected to be stacked together}}
-  %1 = mesh.shard %0 to <@mesh0, [[], [0, 1]]> : tensor<4x8xf32>
-  return %1 : tensor<4x8xf32>
-}
-
-// -----
-
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
-
-func.func @mesh_shard_op_stacked_false_false(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = mesh.shard %arg0 to <@mesh0, [[], [0]]> {as_result = false} : tensor<4x8xf32>
-  // expected-error at +1 {{two mesh.shard ops with as_result = false are not expected to be stacked together}}
-  %1 = mesh.shard %0 to <@mesh0, [[], [0, 1]]> {as_result = false} : tensor<4x8xf32>
-  return %1 : tensor<4x8xf32>
-}
-
-// -----
-
-mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
-
-func.func @mesh_shard_ops_on_same_tensor(%arg0 : tensor<4x8xf32>) -> 
-                                        (tensor<4x8xf32>, tensor<4x8xf32>) {
-  // expected-error at +1 {{when more than one mesh.shard op operates on the same tensor, all of their as_result attributes are expected to be false}}
-  %0 = mesh.shard %arg0 to <@mesh0, [[], [0]]> : tensor<4x8xf32>
-  %1 = mesh.shard %arg0 to <@mesh0, [[], [0, 1]]> : tensor<4x8xf32>
-  return %0, %1 : tensor<4x8xf32>, tensor<4x8xf32>
-}
+// expected-error at +1 {{dimension size of a mesh cluster is expected to be non-negative}}
+mesh.cluster @mesh0(rank = 2, dim_sizes = [-1])
 
 // -----
 
@@ -63,3 +34,14 @@ func.func @mesh_axis_duplicated_2(
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
 }
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_negative(
+    // expected-error at +1 {{mesh axis is expected to be non-negative}}
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index c87f07257d66096..a2c18347a1313f2 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -14,6 +14,7 @@ mesh.cluster @mesh3(rank = 2)
 
 // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
 func.func @mesh_shard_encoding_fully_replicated(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
     %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) -> 
             tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
@@ -21,6 +22,7 @@ func.func @mesh_shard_encoding_fully_replicated(
 
 // CHECK-LABEL: func @mesh_shard_encoding_1st_dim
 func.func @mesh_shard_encoding_1st_dim(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}0]]>>
     %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) -> 
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
@@ -28,6 +30,7 @@ func.func @mesh_shard_encoding_1st_dim(
 
 // CHECK-LABEL: func @mesh_shard_encoding_2nd_dim
 func.func @mesh_shard_encoding_2nd_dim(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh1, {{\[\[}}], [0]]>>
     %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) -> 
     tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>
@@ -35,35 +38,82 @@ func.func @mesh_shard_encoding_2nd_dim(
 
 // CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim
 func.func @mesh_shard_encoding_1st_and_3rd_dim(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32, #mesh.shard<@mesh3, {{\[\[}}0], [], [1]]>>
     %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) -> 
             tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> {
   return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_1st_dim
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
   %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_2nd_dim
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh1, {{\[\[}}], [0]]> : tensor<4x8xf32>
   %0 = mesh.shard %arg0 to <@mesh1, [[], [0]]> : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
 func.func @mesh_shard_op_1st_and_3rd_dim(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
     %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8x16xf32>
   %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32>
   return %0 : tensor<4x8x16xf32>
 }
 
+// CHECK-LABEL: func @mesh_shard_op_partial_max
+func.func @mesh_shard_op_partial_max(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]], partial_max> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_max> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_min
+func.func @mesh_shard_op_partial_min(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]], partial_min> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_min> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_generic
+func.func @mesh_shard_op_partial_generic(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]], partial_generic> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_generic> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_sum_explicit
+func.func @mesh_shard_op_partial_sum_explicit(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_sum> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
 // CHECK-LABEL: func @mesh_shard_op_two_users
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
 func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> 
                                   (tensor<4x8xf32>, tensor<4x8xf32>) {
+  // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
   %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
-  %1 = mesh.shard %0 to <@mesh0, [[1]]> {as_result = false} : tensor<4x8xf32>
-  %2 = mesh.shard %0 to <@mesh0, [[2]]> {as_result = false} : tensor<4x8xf32>
+  // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}1]]> annotate_for_users : tensor<4x8xf32>
+  %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
+  // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}2]]> annotate_for_users : tensor<4x8xf32>
+  %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
   return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
 }

>From 70c92d4e11b758841f819df4a3100b854ddaadf5 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmial.com>
Date: Tue, 3 Oct 2023 21:46:43 +0000
Subject: [PATCH 4/4] fix comments 2nd time

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 39 +++++++++++--------
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  6 +--
 mlir/include/mlir/IR/OpImplementation.h       | 14 ++++++-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 19 ++++++---
 mlir/test/Dialect/Mesh/invalid.mlir           | 28 +++++++++++--
 mlir/test/Dialect/Mesh/ops.mlir               | 29 +++++++++-----
 6 files changed, 96 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 4a5feacd110a80e..dc13923307d604b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -40,10 +40,10 @@ def Mesh_Dialect : Dialect {
 //===----------------------------------------------------------------------===//
 
 def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor", [
-  I32EnumAttrCase<"Sum", 0, "partial_sum">,
-  I32EnumAttrCase<"Max", 1, "partial_max">,
-  I32EnumAttrCase<"Min", 2, "partial_min">,
-  I32EnumAttrCase<"Generic", 100, "partial_generic">
+  I32EnumAttrCase<"Sum", 1, "sum">,
+  I32EnumAttrCase<"Max", 2, "max">,
+  I32EnumAttrCase<"Min", 3, "min">,
+  I32EnumAttrCase<"Generic", 100, "generic">
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::mesh";
@@ -58,9 +58,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
   let parameters = (ins
     AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
-    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$axes,
-    DefaultValuedParameter<"::mlir::mesh::Partial", 
-                           "::mlir::mesh::Partial::Sum">:$partial_type
+    ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
+    OptionalArrayRefParameter<"int8_t">:$partial_axes,
+    OptionalParameter<"::mlir::mesh::Partial">:$partial_type
   );
 
   let summary = "Attribute that extends tensor type to distributed tensor type.";
@@ -68,20 +68,21 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
   let description = [{
     The MeshSharding attribute can be used in the encoding of a
     `RankedTensorType` or the mesh.shard op. It contains these sub-attributes:
+
     1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
     cluster where the distributed tensor is placed. The symbol must resolve to a
     `mesh.cluster` operation.
 
-    2. `axes`: is an array composed of int64_t sub-arrays. The outer array's
-    maximum size is the `rank` of the related tensor plus one. For the i-th
-    sub-array, if its value is [x, y]:
-    - When i < `rank`, it indicates that the tensor's i-th dimension is sharded
+    2. `split_axes`: is an array composed of int8_t sub-arrays. The outer array's
+    maximum size is the `rank` of the related tensor. For the i-th sub-array, if
+    its value is [x, y], it indicates that the tensor's i-th dimension is split
     along the x and y axes of the device mesh.
-    - When i == `rank`, it signifies that the tensor is a partial one along the x
-    and y axes. An all-reduce should be applied to get the whole tensor, and the
-    reduction type is specified by `partial_type`.
 
-    3. `partial_type`: indicates the reduction type of the possible all-reduce
+    3. `partial_axes`: if not empty, this signifies that the tensor is a partial
+    one along the specified mesh axes. An all-reduce should be applied to obtain
+    the complete tensor, with the reduction type specified by `partial_type`.
+
+    4. `partial_type`: indicates the reduction type of the possible all-reduce
     op. It has 4 possible values:
    - `sum`: denotes an all-reduce-sum
    - `max`: denotes an all-reduce-max
@@ -110,13 +111,17 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]], partial_max>
+    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>>
 
     // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
     ```
   }];
-  let assemblyFormat = "`<` $cluster `,` `[` $axes `]` (`,` $partial_type^)?`>`";
+  let assemblyFormat = [{
+    `<` $cluster `,` `[` $split_axes `]` (`,` `partial` `=` $partial_type `[`
+       $partial_axes^ `]`)? `>`
+  }];
+
   let genVerifyDecl = 1;
 }
 
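For reference, the revised attribute can also be built programmatically. Below
is a minimal sketch assuming the C++ API that ODS generates for this AttrDef
(a `MeshShardingAttr::get` taking the parameters in declaration order); the
helper name and the `@mesh0` symbol are illustrative only, not part of the
patch:

```cpp
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Builds #mesh.shard<@mesh0, [[0]], partial = max[1]>: dim 0 is split along
// mesh axis 0, and the tensor is a partial max along mesh axis 1.
// Assumes the ODS-generated ::get with parameters in declaration order.
static mesh::MeshShardingAttr buildSharding(MLIRContext *ctx) {
  SmallVector<DenseI8ArrayAttr> splitAxes = {
      DenseI8ArrayAttr::get(ctx, {int8_t(0)})};
  SmallVector<int8_t> partialAxes = {1};
  return mesh::MeshShardingAttr::get(ctx, SymbolRefAttr::get(ctx, "mesh0"),
                                     splitAxes, partialAxes,
                                     mesh::Partial::Max);
}
```

Such an attribute would round-trip through the assembly format above as
`#mesh.shard<@mesh0, [[0]], partial = max[1]>`.
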
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 3acb7cabdc634f6..8ca4b6653104221 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -47,15 +47,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
 
     Example:
     ```
-    // A device mesh cluster with 3 axes, the totol device number is 4 * 8 * 12
+    // A device mesh cluster with 3 axes, the total device number is 4 * 8 * 12
     // The dimension sizes are 4, 8, 12 
     mesh.cluster @mesh0(rank = 3, dim_sizes = [4, 8, 12])
 
-    // A device mesh cluster with 2 axes, the totol device number is unknown
+    // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is 4 and the second is unknown
     mesh.cluster @mesh1(rank = 2, dim_sizes = [4])
 
-    // A device mesh cluster with 2 axes, the totol device number is unknown
+    // A device mesh cluster with 2 axes, the total device number is unknown
     // The first dimension size is unknown and the second is 4
     mesh.cluster @mesh2(rank = 2, dim_sizes = [0, 4])
 
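As a worked example of the dim_sizes convention in these examples (a 0 entry,
or a missing trailing entry, means the dimension size is unknown), here is a
small standalone sketch, not part of the patch, that computes the total device
count of a cluster when it is fully known:

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Returns the total device count of a mesh cluster, or nullopt when any
// dimension size is unknown (encoded as 0 or omitted at the tail).
static std::optional<int64_t>
totalDevices(int64_t rank, const std::vector<int64_t> &dimSizes) {
  if (static_cast<int64_t>(dimSizes.size()) < rank)
    return std::nullopt; // trailing dimensions are unknown
  int64_t total = 1;
  for (int64_t size : dimSizes) {
    if (size == 0)
      return std::nullopt; // 0 encodes an unknown dimension size
    total *= size;
  }
  return total;
}
```

For @mesh0 above this yields 4 * 8 * 12 = 384; for @mesh1 and @mesh2 it yields
nullopt, matching the "total device number is unknown" comments.
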
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index f1fabf95a68b7ad..379392ace46961a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,7 +350,8 @@ template <typename AsmPrinterT, typename T,
                                !std::is_convertible<T &, Attribute &>::value &&
                                !std::is_convertible<T &, ValueRange>::value &&
                                !std::is_convertible<T &, APFloat &>::value &&
-                               !llvm::is_one_of<T, bool, float, double>::value,
+                               !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
+                                                double>::value,
                            T> * = nullptr>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
@@ -366,6 +367,17 @@ operator<<(AsmPrinterT &p, bool value) {
   return p << (value ? StringRef("true") : "false");
 }
 
+/// Specialization for 8-bit integers to ensure values are printed as integers
+/// and not characters.
+template <
+    typename AsmPrinterT, typename T,
+    std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+                        AsmPrinterT &>
+operator<<(AsmPrinterT &p, T value) {
+  return p << static_cast<int16_t>(value);
+}
+
 template <typename AsmPrinterT, typename ValueRangeT>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
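The widening cast in the new specialization matters because the standard
character-stream overloads treat int8_t/uint8_t as characters. A standalone,
plain-C++ illustration of the pitfall, independent of the AsmPrinter
machinery:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  int8_t axis = 65;
  // int8_t is (signed) char, so the character overload is selected.
  std::cout << axis << "\n";                       // prints "A"
  std::cout << static_cast<int16_t>(axis) << "\n"; // prints "65"
  return 0;
}
```
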
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 5781bc8130aef8a..b2a47102528758c 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -71,21 +71,30 @@ LogicalResult ClusterOp::verify() {
 
 LogicalResult
 MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                         SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> axes,
-                         Partial) {
+                         SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
+                         ArrayRef<int8_t> partialAxes, Partial) {
   // TODO: At present cluster symbol ref is not verified. This is due to the
   // difficulty in fetching the corresponding symbol op based on an attribute.
 
   llvm::SmallSet<int8_t, 4> visitedAxes;
-  for (DenseI8ArrayAttr subAxes : axes) {
-    ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
-    for (int8_t axis : subAxesArray) {
+
+  auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
+    for (int8_t axis : axesArray) {
       if (axis < 0)
         return emitError() << "mesh axis is expected to be non-negative";
       if (!visitedAxes.insert(axis).second)
         return emitError() << "mesh axis duplicated";
     }
+    return success();
+  };
+
+  for (DenseI8ArrayAttr subAxes : splitAxes) {
+    ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+    if (failed(checkMeshAxis(subAxesArray)))
+      return failure();
   }
+  if (failed(checkMeshAxis(partialAxes)))
+    return failure();
 
   return success();
 }
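The rule this verifier enforces is that every mesh axis must be non-negative
and may appear at most once across all of split_axes and partial_axes
combined. A standalone sketch of that rule, with std::set standing in for
llvm::SmallSet and names chosen purely for illustration:

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Returns true iff all axes are non-negative and globally unique across the
// split sub-arrays and the partial axes, mirroring MeshShardingAttr::verify.
static bool verifySharding(const std::vector<std::vector<int8_t>> &splitAxes,
                           const std::vector<int8_t> &partialAxes) {
  std::set<int8_t> visited;
  auto check = [&](const std::vector<int8_t> &axes) {
    for (int8_t axis : axes) {
      if (axis < 0 || !visited.insert(axis).second)
        return false; // negative or duplicated mesh axis
    }
    return true;
  };
  for (const auto &sub : splitAxes)
    if (!check(sub))
      return false;
  return check(partialAxes);
}
```
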
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 39a4645176459f9..246439dd4be7122 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -17,7 +17,7 @@ mesh.cluster @mesh0(rank = 2, dim_sizes = [-1])
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
-func.func @mesh_axis_duplicated(
+func.func @mesh_axis_duplicated_different_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
     %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
@@ -28,7 +28,7 @@ func.func @mesh_axis_duplicated(
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
-func.func @mesh_axis_duplicated_2(
+func.func @mesh_axis_duplicated_same_subarray(
     // expected-error at +1 {{mesh axis duplicated}}
     %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
             tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
@@ -39,9 +39,31 @@ func.func @mesh_axis_duplicated_2(
 
 mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
 
-func.func @mesh_axis_negtive(
+func.func @mesh_axis_duplicated_between_split_and_partial(
+    // expected-error at +1 {{mesh axis duplicated}}
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_negative_in_split_part(
     // expected-error at +1 {{mesh axis is expected to be non-negative}}
     %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
             tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
   return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
 }
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 4])
+
+func.func @mesh_axis_negative_in_partial(
+    // expected-error at +1 {{mesh axis is expected to be non-negative}}
+    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) -> 
+            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
+  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index a2c18347a1313f2..ee5f8f67792b928 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -73,8 +73,8 @@ func.func @mesh_shard_op_1st_and_3rd_dim(
 func.func @mesh_shard_op_partial_max(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]], partial_max> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_max> : tensor<4x8xf32>
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = max[1]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = max[1]> : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -82,8 +82,8 @@ func.func @mesh_shard_op_partial_max(
 func.func @mesh_shard_op_partial_min(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]], partial_min> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_min> : tensor<4x8xf32>
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = min[1]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = min[1]> : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
@@ -91,17 +91,26 @@ func.func @mesh_shard_op_partial_min(
 func.func @mesh_shard_op_partial_generic(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]], partial_generic> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_generic> : tensor<4x8xf32>
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = generic[1]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = generic[1]> : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 
-// CHECK-LABEL: func @mesh_shard_op_partial_sum_explict
-func.func @mesh_shard_op_partial_sum_explict(
+// CHECK-LABEL: func @mesh_shard_op_partial_sum
+func.func @mesh_shard_op_partial_sum(
     // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
     %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8xf32>
-  %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]], partial_sum> : tensor<4x8xf32>
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+
+// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
+func.func @mesh_shard_op_partial_sum_multi_axes(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1, 2]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1, 2]> : tensor<4x8xf32>
   return %0 : tensor<4x8xf32>
 }
 

More information about the Mlir-commits mailing list