[Mlir-commits] [mlir] [mlir][mesh] Refactoring code organization, tests and docs (PR #79606)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jan 30 08:26:54 PST 2024


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/79606

>From 121bdc353a1afe4b493e61117f17f3a30f49b0dc Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 25 Jan 2024 15:36:17 -0800
Subject: [PATCH 1/3] [mlir][mesh] Refactoring code organization, tests and
 docs

* Split out `MeshDialect.h` from `MeshOps.h` that defines the dialect class.
Reduces include clutter if you care only about the dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`.
These functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of mesh.shard attribute in tensor encoding.
Per the decision that Spmdization would be performed on sharding annotations
and there will be no tensors with sharding specified in the type.
For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
---
 .../mlir/Dialect/Mesh/IR/CMakeLists.txt       | 12 ++--
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  8 +--
 .../mlir/Dialect/Mesh/IR/MeshDialect.h        | 16 +++++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   | 37 +++++++++--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  4 --
 mlir/include/mlir/InitAllDialects.h           |  2 +-
 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt       |  4 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 62 +++++--------------
 .../Mesh/Transforms/ShardingPropagation.cpp   |  1 +
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 20 ++----
 .../Dialect/Mesh/Transforms/Transforms.cpp    |  1 +
 mlir/test/Dialect/Mesh/invalid.mlir           | 40 ++++++------
 mlir/test/Dialect/Mesh/ops.mlir               | 36 ++---------
 .../Mesh/TestProcessMultiIndexOpLowering.cpp  |  2 -
 .../lib/Dialect/Mesh/TestSimplifications.cpp  |  2 +-
 15 files changed, 108 insertions(+), 139 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 79a99e58b4e8f..a73ec701dbadc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -2,11 +2,11 @@ add_mlir_dialect(MeshOps mesh)
 add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMeshAttrIncGen)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
+mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRMeshEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index e8353613cd0e1..04929f4869273 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+    #mesh.shard<@mesh0, [[]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
+    #mesh.shard<@mesh0, [[0]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
+    #mesh.shard<@mesh0, [[0]], partial = max[1]>
 
     // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
new file mode 100644
index 0000000000000..a1adbaa44406d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
@@ -0,0 +1,16 @@
+//===- MeshDialect.h - Mesh Dialect -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83452dcc2e8ab..8e5e0f541ba5e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,7 +15,6 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include <algorithm>
 
 namespace mlir {
 namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 } // namespace mesh
 } // namespace mlir
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
-
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+      op, meshSymbol);
+}
+
+// Get the corresponding mesh op using the standard attribute nomenclature.
+template <typename Op>
+mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
+}
+
+// Get the number of processes that participate in each group
+// induced by `meshAxes`.
+template <typename MeshAxesRange, typename MeshShapeRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
+                                   MeshShapeRange &&meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    auto axisSize = *(std::begin(meshShape) + axis);
+    if (ShapedType::isDynamic(axisSize)) {
+      return ShapedType::kDynamic;
+    }
+    res *= axisSize;
+  }
+
+  return res;
+}
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687a..c9efdcc3a68bb 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
     // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
     mesh.mesh @mesh3(shape = ?x?)
-
-    // Used in the mesh sharding attribute to extend the standard tensor to
-    // distributed
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
     ```
   }];
   let arguments = (ins
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e0..d1ee32d7bac61 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -54,7 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 634a94f8cec87..140fb553f8481 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -5,8 +5,8 @@ add_mlir_dialect_library(MLIRMeshDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
 
   DEPENDS
-  MLIRMeshOpsAttrIncGen
-  MLIRMeshOpsEnumsIncGen
+  MLIRMeshAttrIncGen
+  MLIRMeshEnumsIncGen
   MLIRMeshOpsIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 994a017a1f46c..b1cc36d1879e1 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -43,24 +45,6 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
-template <typename It>
-static It canonicalizeSetAsArray(It begin, It end) {
-  llvm::sort(begin, end);
-  return std::unique(begin, end);
-}
-
-template <typename R>
-static auto canonicalizeSetAsArray(R &&range) {
-  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
-}
-
-template <typename T>
-static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
-  auto newEnd = canonicalizeSetAsArray(vec);
-  vec.resize(newEnd - vec.begin());
-  return vec;
-}
-
 namespace {
 
 struct DimensionSize {
@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                 SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
+                                          FlatSymbolRefAttr meshSymbol,
+                                          SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
   if (rank <= 0)
     return emitOpError("rank of mesh is expected to be a positive integer");
 
-  if (getShape().size() > size_t(rank))
-    return emitOpError(
-        "rank of shape is not expected to be larger than rank of mesh");
-
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
 
 LogicalResult
 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 template <typename Op>
 static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  auto mesh =
+      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -450,21 +431,6 @@ static auto product(R &&range) {
   return product(adl_begin(range), adl_end(range));
 }
 
-static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  int64_t res = 1;
-
-  for (MeshAxis axis : meshAxes) {
-    if (ShapedType::isDynamic(meshShape[axis])) {
-      return ShapedType::kDynamic;
-    }
-    assert(size_t(axis) < meshShape.size());
-    res *= meshShape[axis];
-  }
-
-  return res;
-}
-
 static LogicalResult verifyDimensionCompatibility(Location loc,
                                                   int64_t expectedDimSize,
                                                   int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(scatterAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 3aed912fb43c6..f3cd12f38879d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 593158d5f6d29..5554edac4d2f6 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
   return dim * shardCount;
 }
 
-template <typename MeshShape, typename SplitAxes>
-int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
-  int64_t res = 1;
-  for (auto splitAxis : splitAxes) {
-    int64_t meshDimSize = meshShape[splitAxis];
-    if (ShapedType::isDynamic(meshDimSize)) {
-      return ShapedType::kDynamic;
-    }
-    res *= meshDimSize;
-  }
-  return res;
-}
-
 // Compute the shape for the tensor on each device in the mesh.
 // Example:
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
   for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] =
-        shardDimension(inShape[tensorAxis],
-                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+    outShape[tensorAxis] = shardDimension(
+        inShape[tensorAxis],
+        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
   }
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 5c2344651bf60..03b1d9b349802 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 259e4ebf76757..3fa3ebd67b15e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -13,10 +13,10 @@ mesh.mesh @mesh0(shape = -1)
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -24,10 +24,10 @@ func.func @mesh_axis_duplicated_different_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -35,10 +35,10 @@ func.func @mesh_axis_duplicated_same_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -46,10 +46,10 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
-    // expected-error at +1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -57,10 +57,10 @@ func.func @mesh_axis_negtive_in_split_part(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
-    // expected-error at +1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index dbaaff9c172fd..40a8469b26464 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -17,36 +17,12 @@ mesh.mesh @mesh4(shape = 3)
 // CHECK: mesh.mesh @mesh5(shape = ?)
 mesh.mesh @mesh5(shape = ?)
 
-// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
-func.func @mesh_shard_encoding_fully_replicated(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_1st_dim
-func.func @mesh_shard_encoding_1st_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}0]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_2nd_dim
-func.func @mesh_shard_encoding_2nd_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh1, {{\[\[}}], [0]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) -> 
-    tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim
-func.func @mesh_shard_encoding_1st_and_3rd_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32, #mesh.shard<@mesh3, {{\[\[}}0], [], [1]]>>
-    %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) -> 
-            tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> {
-  return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>
+// CHECK-LABEL: func @mesh_shard_op_fully_replicated
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}]]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh0, [[]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_1st_dim
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
index 5018a308a3bdc..0bcc403a2734e 100644
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/SymbolTable.h"
@@ -24,7 +23,6 @@ struct TestMultiIndexOpLoweringPass
 
   void runOnOperation() override;
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<mesh::MeshDialect>();
     mesh::processMultiIndexOpLoweringRegisterDialects(registry);
   }
   StringRef getArgument() const final {
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
index cd22530d51a75..512b16af64c94 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"

>From 3dec93e25f9cceabc236add1c350463d7d672520 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 26 Jan 2024 09:01:23 -0800
Subject: [PATCH 2/3] Fix includes

---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b1cc36d1879e1..79cf5a7c591a3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -85,7 +85,7 @@ void MeshDialect::initialize() {
       >();
   addAttributes<
 #define GET_ATTRDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
       >();
 }
 
@@ -812,6 +812,6 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"

>From 7664c78ead6fb5a32fd7acb14e0346a5a896d91c Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 30 Jan 2024 08:24:03 -0800
Subject: [PATCH 3/3] Change naming of internal CMake targets

---
 mlir/docs/Dialects/Mesh.md                      |  3 +--
 .../include/mlir/Dialect/Mesh/IR/CMakeLists.txt | 17 +++++++++++++----
 mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h |  2 +-
 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt         |  4 +---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp            |  4 ++--
 5 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 77da2f10d8902..4b7eb90324763 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -47,11 +47,10 @@ For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
 an in-group device with `(i, j)`. Then for each group with index `g` on the
 second axis, the in-group device would be `(i, g, j)`.
 
-
 ## Operations
 
 [include "Dialects/MeshOps.md"]
 
 ## Attributes
 
-[include "Dialects/MeshAttributes.md"]
+[include "Dialects/MeshAttrs.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index a73ec701dbadc..7ba966d8cab7c 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -1,12 +1,21 @@
-add_mlir_dialect(MeshOps mesh)
-add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
+add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh)
+add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh)
+
+set(LLVM_TARGET_DEFINITIONS MeshOps.td)
+mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh)
+mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
 mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMeshAttrIncGen)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
 mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
 mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRMeshEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS MeshOps.td)
+mlir_tablegen(MeshOps.h.inc -gen-op-decls)
+mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
+
+add_public_tablegen_target(MLIRMeshIncGen)
+add_dependencies(mlir-headers MLIRMeshIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
index a1adbaa44406d..a30cf91e851fe 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
@@ -11,6 +11,6 @@
 
 #include "mlir/IR/Dialect.h"
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc"
 
 #endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 140fb553f8481..678a25f1c3cf5 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -5,9 +5,7 @@ add_mlir_dialect_library(MLIRMeshDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
 
   DEPENDS
-  MLIRMeshAttrIncGen
-  MLIRMeshEnumsIncGen
-  MLIRMeshOpsIncGen
+  MLIRMeshIncGen
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 79cf5a7c591a3..7c6a5e02ed583 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
-#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -43,7 +43,7 @@
 using namespace mlir;
 using namespace mlir::mesh;
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
 
 namespace {
 



More information about the Mlir-commits mailing list