[Mlir-commits] [mlir] [mlir][mesh] Refactoring code organization, tests and docs (PR #79606)

Boian Petkantchin llvmlistbot at llvm.org
Fri Jan 26 07:29:55 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/79606

* Split out `MeshDialect.h`, which defines the dialect class, from `MeshOps.h`. This reduces include clutter if you only care about the dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`. These functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of the mesh.sharding attribute used as a tensor encoding. It was decided that spmdization will be performed on sharding annotations, so there will be no tensors with sharding specified in their type. For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81

>From 66744f085474a46e85cdaaedb8babf1e7468e1a0 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 25 Jan 2024 15:36:17 -0800
Subject: [PATCH] [mlir][mesh] Refactoring code organization, tests and docs

* Split out `MeshDialect.h`, which defines the dialect class, from `MeshOps.h`.
This reduces include clutter if you only care about the dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`.
These functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of the mesh.sharding attribute as a tensor encoding.
It was decided that spmdization will be performed on sharding annotations,
so there will be no tensors with sharding specified in their type.
For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81
---
 .../mlir/Dialect/Mesh/IR/CMakeLists.txt       | 12 ++--
 mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td |  8 +--
 .../mlir/Dialect/Mesh/IR/MeshDialect.h        | 16 +++++
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   | 37 +++++++++--
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  |  4 --
 mlir/include/mlir/InitAllDialects.h           |  2 +-
 mlir/lib/Dialect/Mesh/IR/CMakeLists.txt       |  4 +-
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 62 +++++--------------
 .../Mesh/Transforms/ShardingPropagation.cpp   |  1 +
 .../Dialect/Mesh/Transforms/Spmdization.cpp   | 20 ++----
 .../Dialect/Mesh/Transforms/Transforms.cpp    |  1 +
 mlir/test/Dialect/Mesh/invalid.mlir           | 40 ++++++------
 mlir/test/Dialect/Mesh/ops.mlir               | 36 ++---------
 .../Mesh/TestProcessMultiIndexOpLowering.cpp  |  2 -
 .../lib/Dialect/Mesh/TestSimplifications.cpp  |  2 +-
 15 files changed, 108 insertions(+), 139 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 79a99e58b4e8f10..a73ec701dbadcc2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -2,11 +2,11 @@ add_mlir_dialect(MeshOps mesh)
 add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMeshAttrIncGen)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
+mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRMeshEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index e8353613cd0e12f..04929f4869273d0 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+    #mesh.shard<@mesh0, [[]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
+    #mesh.shard<@mesh0, [[0]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
+    #mesh.shard<@mesh0, [[0]], partial = max[1]>
 
     // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
new file mode 100644
index 000000000000000..a1adbaa44406d6d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
@@ -0,0 +1,16 @@
+//===- MeshDialect.h - Mesh Dialect -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83452dcc2e8abeb..8e5e0f541ba5ee6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,7 +15,6 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include <algorithm>
 
 namespace mlir {
 namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 } // namespace mesh
 } // namespace mlir
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
-
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+// Looks up `meshSymbol` from `op`; yields a null MeshOp if it is undefined.
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  return symbolTableCollection.lookupNearestSymbolFrom<MeshOp>(op, meshSymbol);
+}
+
+// Convenience overload for ops whose mesh symbol is exposed through the
+// conventional `getMeshAttr()` accessor.
+template <typename Op>
+mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
}
+
+// Get the number of processes that participate in each group induced by
+// `meshAxes`, or ShapedType::kDynamic if any of those mesh axes is dynamic.
+template <typename MeshAxesRange, typename MeshShapeRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
+                                   MeshShapeRange &&meshShape) {
+  int64_t res = 1;
+  for (MeshAxis axis : meshAxes) {
+    assert(axis >= 0 &&
+           axis < std::distance(std::begin(meshShape), std::end(meshShape)) &&
+           "mesh axis out of range of the mesh shape");
+    int64_t axisSize = *(std::begin(meshShape) + axis);
+    if (ShapedType::isDynamic(axisSize))
+      return ShapedType::kDynamic;
+    res *= axisSize;
+  }
+  return res;
+}
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687ae3..c9efdcc3a68bb83 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
     // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
     mesh.mesh @mesh3(shape = ?x?)
-
-    // Used in the mesh sharding attribute to extend the standard tensor to
-    // distributed
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
     ```
   }];
   let arguments = (ins
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04f..d1ee32d7bac6130 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -54,7 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 634a94f8cec8798..140fb553f84817a 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -5,8 +5,8 @@ add_mlir_dialect_library(MLIRMeshDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
 
   DEPENDS
-  MLIRMeshOpsAttrIncGen
-  MLIRMeshOpsEnumsIncGen
+  MLIRMeshAttrIncGen
+  MLIRMeshEnumsIncGen
   MLIRMeshOpsIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 994a017a1f46c2e..b1cc36d1879e191 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -43,24 +45,6 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
-template <typename It>
-static It canonicalizeSetAsArray(It begin, It end) {
-  llvm::sort(begin, end);
-  return std::unique(begin, end);
-}
-
-template <typename R>
-static auto canonicalizeSetAsArray(R &&range) {
-  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
-}
-
-template <typename T>
-static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
-  auto newEnd = canonicalizeSetAsArray(vec);
-  vec.resize(newEnd - vec.begin());
-  return vec;
-}
-
 namespace {
 
 struct DimensionSize {
@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                 SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
+                                          FlatSymbolRefAttr meshSymbol,
+                                          SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
   if (rank <= 0)
     return emitOpError("rank of mesh is expected to be a positive integer");
 
-  if (getShape().size() > size_t(rank))
-    return emitOpError(
-        "rank of shape is not expected to be larger than rank of mesh");
-
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
 
 LogicalResult
 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 template <typename Op>
 static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  auto mesh =
+      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -450,21 +431,6 @@ static auto product(R &&range) {
   return product(adl_begin(range), adl_end(range));
 }
 
-static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  int64_t res = 1;
-
-  for (MeshAxis axis : meshAxes) {
-    if (ShapedType::isDynamic(meshShape[axis])) {
-      return ShapedType::kDynamic;
-    }
-    assert(size_t(axis) < meshShape.size());
-    res *= meshShape[axis];
-  }
-
-  return res;
-}
-
 static LogicalResult verifyDimensionCompatibility(Location loc,
                                                   int64_t expectedDimSize,
                                                   int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(scatterAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 3aed912fb43c63e..f3cd12f38879d8e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 593158d5f6d293e..5554edac4d2f631 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
   return dim * shardCount;
 }
 
-template <typename MeshShape, typename SplitAxes>
-int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
-  int64_t res = 1;
-  for (auto splitAxis : splitAxes) {
-    int64_t meshDimSize = meshShape[splitAxis];
-    if (ShapedType::isDynamic(meshDimSize)) {
-      return ShapedType::kDynamic;
-    }
-    res *= meshDimSize;
-  }
-  return res;
-}
-
 // Compute the shape for the tensor on each device in the mesh.
 // Example:
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
   for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] =
-        shardDimension(inShape[tensorAxis],
-                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+    outShape[tensorAxis] = shardDimension(
+        inShape[tensorAxis],
+        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
   }
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 5c2344651bf60d0..03b1d9b34980281 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 259e4ebf76757c2..3fa3ebd67b15e7e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -13,10 +13,10 @@ mesh.mesh @mesh0(shape = -1)
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -24,10 +24,10 @@ func.func @mesh_axis_duplicated_different_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -35,10 +35,10 @@ func.func @mesh_axis_duplicated_same_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -46,10 +46,10 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
-    // expected-error at +1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -57,10 +57,10 @@ func.func @mesh_axis_negtive_in_split_part(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
-    // expected-error at +1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index dbaaff9c172fd92..40a8469b264643f 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -17,36 +17,12 @@ mesh.mesh @mesh4(shape = 3)
 // CHECK: mesh.mesh @mesh5(shape = ?)
 mesh.mesh @mesh5(shape = ?)
 
-// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
-func.func @mesh_shard_encoding_fully_replicated(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_1st_dim
-func.func @mesh_shard_encoding_1st_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}0]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_2nd_dim
-func.func @mesh_shard_encoding_2nd_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh1, {{\[\[}}], [0]]>>
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) -> 
-    tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>
-}
-
-// CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim
-func.func @mesh_shard_encoding_1st_and_3rd_dim(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32, #mesh.shard<@mesh3, {{\[\[}}0], [], [1]]>>
-    %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) -> 
-            tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> {
-  return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>
+// CHECK-LABEL: func @mesh_shard_op_fully_replicated
+// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
+func.func @mesh_shard_op_fully_replicated(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}]]> : tensor<4x8xf32>
+  %0 = mesh.shard %arg0 to <@mesh0, [[]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // CHECK-LABEL: func @mesh_shard_op_1st_dim
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
index 5018a308a3bdcd3..0bcc403a2734eda 100644
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/SymbolTable.h"
@@ -24,7 +23,6 @@ struct TestMultiIndexOpLoweringPass
 
   void runOnOperation() override;
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<mesh::MeshDialect>();
     mesh::processMultiIndexOpLoweringRegisterDialects(registry);
   }
   StringRef getArgument() const final {
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
index cd22530d51a75b5..512b16af64c945c 100644
--- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"



More information about the Mlir-commits mailing list