[Mlir-commits] [mlir] 31fc0a1 - [mlir][mesh] Refactoring code organization, tests and docs (#79606)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 31 07:20:21 PST 2024


Author: Boian Petkantchin
Date: 2024-01-31T07:20:14-08:00
New Revision: 31fc0a12e1552e6bcea63ae740f284eaf74f4c17

URL: https://github.com/llvm/llvm-project/commit/31fc0a12e1552e6bcea63ae740f284eaf74f4c17
DIFF: https://github.com/llvm/llvm-project/commit/31fc0a12e1552e6bcea63ae740f284eaf74f4c17.diff

LOG: [mlir][mesh] Refactoring code organization, tests and docs (#79606)

* Split out `MeshDialect.h` from `MeshOps.h`. The new header defines the
dialect class, which reduces include clutter if you care only about the
dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`. These
functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of mesh.shard attribute in tensor encoding.
This follows the decision that Spmdization will be performed on sharding
annotations and that there will be no tensors with sharding specified in
the type.
For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81

Added: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h

Modified: 
    mlir/docs/Dialects/Mesh.md
    mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
    mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
    mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
    mlir/test/Dialect/Mesh/invalid.mlir
    mlir/test/Dialect/Mesh/ops.mlir
    mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
    mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
index 6df3f53c9f017..5eb6569c7044b 100644
--- a/mlir/docs/Dialects/Mesh.md
+++ b/mlir/docs/Dialects/Mesh.md
@@ -46,7 +46,6 @@ These are the axes specified by `mesh_axes` attribute.
 For Example on a 3D mesh an operation with `mesh_axes = [0, 2]` would specify
 an in-group device with `(i, j)`. Then for each group with index `g` on the
 second axis, the in-group device would be `(i, g, j)`.
-
 ### Purity
 Collectives that involve the whole device group to perform a single operation
 are pure. The exceptions are `send` and `recv`.
@@ -72,4 +71,4 @@ passes like dead code and common sub-expression elimination.
 
 ## Attributes
 
-[include "Dialects/MeshAttributes.md"]
+[include "Dialects/MeshAttrs.md"]

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 79a99e58b4e8f..7ba966d8cab7c 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -1,12 +1,21 @@
-add_mlir_dialect(MeshOps mesh)
-add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
+add_mlir_doc(MeshOps MeshOps Dialects/ -gen-op-doc -dialect=mesh)
+add_mlir_doc(MeshOps MeshAttrs Dialects/ -gen-attrdef-doc -dialect=mesh)
+
+set(LLVM_TARGET_DEFINITIONS MeshOps.td)
+mlir_tablegen(MeshDialect.cpp.inc -gen-dialect-defs -dialect=mesh)
+mlir_tablegen(MeshDialect.h.inc -gen-dialect-decls -dialect=mesh)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
+mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS MeshOps.td)
+mlir_tablegen(MeshOps.h.inc -gen-op-decls)
+mlir_tablegen(MeshOps.cpp.inc -gen-op-defs)
+
+add_public_tablegen_target(MLIRMeshIncGen)
+add_dependencies(mlir-headers MLIRMeshIncGen)

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index e8353613cd0e1..04929f4869273 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+    #mesh.shard<@mesh0, [[]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
+    #mesh.shard<@mesh0, [[0]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
+    #mesh.shard<@mesh0, [[0]], partial = max[1]>
 
     // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
new file mode 100644
index 0000000000000..a30cf91e851fe
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
@@ -0,0 +1,16 @@
+//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83452dcc2e8ab..8e5e0f541ba5e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,7 +15,6 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include <algorithm>
 
 namespace mlir {
 namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 } // namespace mesh
 } // namespace mlir
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
-
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+      op, meshSymbol);
+}
+
+// Get the corresponding mesh op using the standard attribute nomenclature.
+template <typename Op>
+mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
+}
+
+// Get the number of processes that participate in each group
+// induced by `meshAxes`.
+template <typename MeshAxesRange, typename MeshShapeRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
+                                   MeshShapeRange &&meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    auto axisSize = *(std::begin(meshShape) + axis);
+    if (ShapedType::isDynamic(axisSize)) {
+      return ShapedType::kDynamic;
+    }
+    res *= axisSize;
+  }
+
+  return res;
+}
+
 } // namespace mesh
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a28b9b429a246..da372706ec724 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
     // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
     mesh.mesh @mesh3(shape = ?x?)
-
-    // Used in the mesh sharding attribute to extend the standard tensor to
-    // distributed
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
     ```
   }];
   let arguments = (ins

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 0d21ecb4ebb44..3c145540356bd 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -55,7 +55,7 @@
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"

diff  --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 634a94f8cec87..678a25f1c3cf5 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -5,9 +5,7 @@ add_mlir_dialect_library(MLIRMeshDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
 
   DEPENDS
-  MLIRMeshOpsAttrIncGen
-  MLIRMeshOpsEnumsIncGen
-  MLIRMeshOpsIncGen
+  MLIRMeshIncGen
 
   LINK_LIBS PUBLIC
   MLIRArithDialect

diff  --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 994a017a1f46c..7c6a5e02ed583 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -7,7 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -41,25 +43,7 @@
 using namespace mlir;
 using namespace mlir::mesh;
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
-
-template <typename It>
-static It canonicalizeSetAsArray(It begin, It end) {
-  llvm::sort(begin, end);
-  return std::unique(begin, end);
-}
-
-template <typename R>
-static auto canonicalizeSetAsArray(R &&range) {
-  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
-}
-
-template <typename T>
-static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
-  auto newEnd = canonicalizeSetAsArray(vec);
-  vec.resize(newEnd - vec.begin());
-  return vec;
-}
+#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc"
 
 namespace {
 
@@ -101,7 +85,7 @@ void MeshDialect::initialize() {
       >();
   addAttributes<
 #define GET_ATTRDEF_LIST
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
       >();
 }
 
@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                 SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
+                                          FlatSymbolRefAttr meshSymbol,
+                                          SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
   if (rank <= 0)
     return emitOpError("rank of mesh is expected to be a positive integer");
 
-  if (getShape().size() > size_t(rank))
-    return emitOpError(
-        "rank of shape is not expected to be larger than rank of mesh");
-
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
 
 LogicalResult
 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 template <typename Op>
 static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  auto mesh =
+      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -450,21 +431,6 @@ static auto product(R &&range) {
   return product(adl_begin(range), adl_end(range));
 }
 
-static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  int64_t res = 1;
-
-  for (MeshAxis axis : meshAxes) {
-    if (ShapedType::isDynamic(meshShape[axis])) {
-      return ShapedType::kDynamic;
-    }
-    assert(size_t(axis) < meshShape.size());
-    res *= meshShape[axis];
-  }
-
-  return res;
-}
-
 static LogicalResult verifyDimensionCompatibility(Location loc,
                                                   int64_t expectedDimSize,
                                                   int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(scatterAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
@@ -846,6 +812,6 @@ void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 #include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc"
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc"

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 3aed912fb43c6..f3cd12f38879d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Pass/Pass.h"

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 593158d5f6d29..5554edac4d2f6 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
   return dim * shardCount;
 }
 
-template <typename MeshShape, typename SplitAxes>
-int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
-  int64_t res = 1;
-  for (auto splitAxis : splitAxes) {
-    int64_t meshDimSize = meshShape[splitAxis];
-    if (ShapedType::isDynamic(meshDimSize)) {
-      return ShapedType::kDynamic;
-    }
-    res *= meshDimSize;
-  }
-  return res;
-}
-
 // Compute the shape for the tensor on each device in the mesh.
 // Example:
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
   for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] =
-        shardDimension(inShape[tensorAxis],
-                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+    outShape[tensorAxis] = shardDimension(
+        inShape[tensorAxis],
+        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
   }
 }
 

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 5c2344651bf60..03b1d9b349802 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"

diff  --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 259e4ebf76757..3fa3ebd67b15e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -13,10 +13,10 @@ mesh.mesh @mesh0(shape = -1)
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
-    // expected-error @+1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @+1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -24,10 +24,10 @@ func.func @mesh_axis_duplicated_different_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
-    // expected-error @+1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @+1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -35,10 +35,10 @@ func.func @mesh_axis_duplicated_same_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
-    // expected-error @+1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @+1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -46,10 +46,10 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
-    // expected-error @+1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @+1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -57,10 +57,10 @@ func.func @mesh_axis_negtive_in_split_part(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
-    // expected-error @+1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error @+1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index dbaaff9c172fd..40a8469b26464 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -17,36 +17,12 @@ mesh.mesh @mesh4(shape = 3)
 // CHECK: mesh.mesh @mesh5(shape = ?)
 mesh.mesh @mesh5(shape = ?)
 
-// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
-func.func @mesh_shard_encoding_fully_replicated(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_1st_dim
-func.func @mesh_shard_encoding_1st_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}0]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_2nd_dim
-func.func @mesh_shard_encoding_2nd_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh1, {{\[\[}}], [0]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) -> 
-    tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim
-func.func @mesh_shard_encoding_1st_and_3rd_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32, #mesh.shard<@mesh3, {{\[\[}}0], [], [1]]>>
-    %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) -> 
-            tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> {
-  return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>
+// CHECK-LABEL: func @mesh_shard_op_fully_replicated
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}]]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh0, [[]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_1st_dim

diff  --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
index 5018a308a3bdc..0bcc403a2734e 100644
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/SymbolTable.h"
@@ -24,7 +23,6 @@ struct TestMultiIndexOpLoweringPass
 
   void runOnOperation() override;
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<mesh::MeshDialect>();
     mesh::processMultiIndexOpLoweringRegisterDialects(registry);
   }
   StringRef getArgument() const final {

diff  --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
index cd22530d51a75..512b16af64c94 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"


        


More information about the Mlir-commits mailing list