[Mlir-commits] [mlir] [mlir][mesh] Refactoring code organization, tests and docs (PR #79606)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 26 07:30:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

* Split out `MeshDialect.h` form `MeshOps.h` that defines the dialect class. Reduces include clutter if you care only about the dialect and not the ops.

* Expose functions `getMesh` and `collectiveProcessGroupSize`. There functions are useful for outside users of the dialect.

* Remove unused code.

* Remove examples and tests of mesh.sharding attribute in tensor encoding. Per the decision that Spmdization would be performed on sharding annotations and there will be no tensors with sharding specified in the type. For more info see this RFC comment:
https://discourse.llvm.org/t/rfc-sharding-framework-design-for-device-mesh/73533/81

---

Patch is 23.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79606.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt (+6-6) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+4-4) 
- (added) mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h (+16) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+32-5) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (-4) 
- (modified) mlir/include/mlir/InitAllDialects.h (+1-1) 
- (modified) mlir/lib/Dialect/Mesh/IR/CMakeLists.txt (+2-2) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+14-48) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+1) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+4-16) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+1) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+20-20) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+6-30) 
- (modified) mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp (-2) 
- (modified) mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
index 79a99e58b4e8f1..a73ec701dbadcc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt
@@ -2,11 +2,11 @@ add_mlir_dialect(MeshOps mesh)
 add_mlir_doc(MeshOps MeshOps Dialects/ -gen-dialect-doc -dialect=mesh)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsAttributes.h.inc -gen-attrdef-decls)
-mlir_tablegen(MeshOpsAttributes.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRMeshOpsAttrIncGen)
+mlir_tablegen(MeshAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(MeshAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRMeshAttrIncGen)
 
 set(LLVM_TARGET_DEFINITIONS MeshBase.td)
-mlir_tablegen(MeshOpsEnums.h.inc -gen-enum-decls)
-mlir_tablegen(MeshOpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(MLIRMeshOpsEnumsIncGen)
+mlir_tablegen(MeshEnums.h.inc -gen-enum-decls)
+mlir_tablegen(MeshEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRMeshEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index e8353613cd0e12..04929f4869273d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -123,18 +123,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
     // The tensor is fully replicated on @mesh0.
     // Currently, there must be at least one sub-array present in axes, even
     // if it's empty. Otherwise, a parsing error will occur.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>
+    #mesh.shard<@mesh0, [[]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>
+    #mesh.shard<@mesh0, [[0]]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_sum along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
+    #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
 
     // The tensor is sharded on the first dimension along axis 0 of @mesh0 and
     // it is also a partial_max along mesh axis 1.
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial = max[1]>
+    #mesh.shard<@mesh0, [[0]], partial = max[1]>
 
     // Could be used in the attribute of mesh.shard op
     %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
new file mode 100644
index 00000000000000..a1adbaa44406d6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshDialect.h
@@ -0,0 +1,16 @@
+//===- MeshOps.h - Mesh Dialect ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+#define MLIR_DIALECT_MESH_IR_MESHDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_MESH_IR_MESHDIALECT_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 83452dcc2e8abe..8e5e0f541ba5ee 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -15,7 +15,6 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include <algorithm>
 
 namespace mlir {
 namespace mesh {
@@ -26,12 +25,10 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 } // namespace mesh
 } // namespace mlir
 
-#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
-
-#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
 
 #define GET_ATTRDEF_CLASSES
-#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.h.inc"
+#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
@@ -51,6 +48,36 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 Partial getPartialTypeFromReduction(IteratorType iType);
 
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
+      op, meshSymbol);
+}
+
+// Get the corresponding mesh op using the standard attribute nomenclature.
+template <typename Op>
+mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
+  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
+}
+
+// Get the number of processes that participate in each group
+// induced by `meshAxes`.
+template <typename MeshAxesRange, typename MeshShapeRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
+                                   MeshShapeRange &&meshShape) {
+  int64_t res = 1;
+
+  for (MeshAxis axis : meshAxes) {
+    auto axisSize = *(std::begin(meshShape) + axis);
+    if (ShapedType::isDynamic(axisSize)) {
+      return ShapedType::kDynamic;
+    }
+    res *= axisSize;
+  }
+
+  return res;
+}
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7b301025e687ae..c9efdcc3a68bb8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -61,10 +61,6 @@ def Mesh_MeshOp : Mesh_Op<"mesh", [Symbol]> {
     // A device mesh with 2 axes, the number of devices along both axes
     // is unknown
     mesh.mesh @mesh3(shape = ?x?)
-
-    // Used in the mesh sharding attribute to extend the standard tensor to
-    // distributed
-    tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>
     ```
   }];
   let arguments = (ins
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..d1ee32d7bac613 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -54,7 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
-#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
index 634a94f8cec879..140fb553f84817 100644
--- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt
@@ -5,8 +5,8 @@ add_mlir_dialect_library(MLIRMeshDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
 
   DEPENDS
-  MLIRMeshOpsAttrIncGen
-  MLIRMeshOpsEnumsIncGen
+  MLIRMeshAttrIncGen
+  MLIRMeshEnumsIncGen
   MLIRMeshOpsIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 994a017a1f46c2..b1cc36d1879e19 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
+
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
@@ -43,24 +45,6 @@ using namespace mlir::mesh;
 
 #include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
 
-template <typename It>
-static It canonicalizeSetAsArray(It begin, It end) {
-  llvm::sort(begin, end);
-  return std::unique(begin, end);
-}
-
-template <typename R>
-static auto canonicalizeSetAsArray(R &&range) {
-  return canonicalizeSetAsArray(adl_begin(range), adl_end(range));
-}
-
-template <typename T>
-static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
-  auto newEnd = canonicalizeSetAsArray(vec);
-  vec.resize(newEnd - vec.begin());
-  return vec;
-}
-
 namespace {
 
 struct DimensionSize {
@@ -114,10 +98,10 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Mesh utilities
 //===----------------------------------------------------------------------===//
 
-static FailureOr<MeshOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                                 SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh =
-      symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>(op, meshSymbol);
+static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
+                                          FlatSymbolRefAttr meshSymbol,
+                                          SymbolTableCollection &symbolTable) {
+  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -201,10 +185,6 @@ LogicalResult MeshOp::verify() {
   if (rank <= 0)
     return emitOpError("rank of mesh is expected to be a positive integer");
 
-  if (getShape().size() > size_t(rank))
-    return emitOpError(
-        "rank of shape is not expected to be larger than rank of mesh");
-
   for (int64_t dimSize : getShape()) {
     if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
       return emitOpError("dimension size of a mesh is expected to be "
@@ -220,7 +200,7 @@ LogicalResult MeshOp::verify() {
 
 LogicalResult
 MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -322,7 +302,7 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
 
 LogicalResult
 ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -360,7 +340,7 @@ void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
 
 LogicalResult
 ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -428,7 +408,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
 template <typename Op>
 static FailureOr<MeshOp>
 getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
-  auto mesh = ::getMesh(op.getOperation(), op.getMeshAttr(), symbolTable);
+  auto mesh =
+      ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable);
   if (failed(mesh)) {
     return failure();
   }
@@ -450,21 +431,6 @@ static auto product(R &&range) {
   return product(adl_begin(range), adl_end(range));
 }
 
-static int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
-                                         ArrayRef<int64_t> meshShape) {
-  int64_t res = 1;
-
-  for (MeshAxis axis : meshAxes) {
-    if (ShapedType::isDynamic(meshShape[axis])) {
-      return ShapedType::kDynamic;
-    }
-    assert(size_t(axis) < meshShape.size());
-    res *= meshShape[axis];
-  }
-
-  return res;
-}
-
 static LogicalResult verifyDimensionCompatibility(Location loc,
                                                   int64_t expectedDimSize,
                                                   int64_t resultDimSize,
@@ -495,7 +461,7 @@ static LogicalResult verifyGatherOperandAndResultShape(
   ShapedType operandType = operand.getType().cast<ShapedType>();
   ShapedType resultType = result.getType().cast<ShapedType>();
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
     auto operandDimSize = DimensionSize(operandType.getDimSize(axis));
     auto resultDimSize = DimensionSize(resultType.getDimSize(axis));
@@ -529,7 +495,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis));
   auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis));
   DimensionSize expectedResultConcatDimSize =
@@ -570,7 +536,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
   }
 
   auto deviceGroupSize =
-      DimensionSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+      DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
   auto operandScatterDimSize =
       DimensionSize(operandType.getDimSize(scatterAxis));
   if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() &&
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 3aed912fb43c63..f3cd12f38879d8 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Mesh/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Pass/Pass.h"
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 593158d5f6d293..5554edac4d2f63 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -54,19 +55,6 @@ int64_t unshardDimension(int64_t dim, int64_t shardCount) {
   return dim * shardCount;
 }
 
-template <typename MeshShape, typename SplitAxes>
-int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
-  int64_t res = 1;
-  for (auto splitAxis : splitAxes) {
-    int64_t meshDimSize = meshShape[splitAxis];
-    if (ShapedType::isDynamic(meshDimSize)) {
-      return ShapedType::kDynamic;
-    }
-    res *= meshDimSize;
-  }
-  return res;
-}
-
 // Compute the shape for the tensor on each device in the mesh.
 // Example:
 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
@@ -78,9 +66,9 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
   for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-    outShape[tensorAxis] =
-        shardDimension(inShape[tensorAxis],
-                       shardCount(meshShape, innerSplitAxes.asArrayRef()));
+    outShape[tensorAxis] = shardDimension(
+        inShape[tensorAxis],
+        collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
   }
 }
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 5c2344651bf60d..03b1d9b3498028 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 259e4ebf76757c..3fa3ebd67b15e7 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -13,10 +13,10 @@ mesh.mesh @mesh0(shape = -1)
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_different_subarray(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0], [0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -24,10 +24,10 @@ func.func @mesh_axis_duplicated_different_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_same_subarray(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0, 0]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0, 0]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -35,10 +35,10 @@ func.func @mesh_axis_duplicated_same_subarray(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_duplicated_bewteen_split_and_partial(
-    // expected-error at +1 {{mesh axis duplicated}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[0]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis duplicated}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[0]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -46,10 +46,10 @@ func.func @mesh_axis_duplicated_bewteen_split_and_partial(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_split_part(
-    // expected-error at +1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[-1]]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[-1]]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
@@ -57,10 +57,10 @@ func.func @mesh_axis_negtive_in_split_part(
 mesh.mesh @mesh0(shape = 2x4)
 
 func.func @mesh_axis_negtive_in_partial(
-    // expected-error at +1 {{mesh axis is expected to be non-negative}}
-    %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>) -> 
-            tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>> {
-  return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]], partial=max[-1]>>
+    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
+  // expected-error at +1 {{mesh axis is expected to be non-negative}}
+  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial=max[-1]> : tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index dbaaff9c172fd9..40a8469b264643 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -17,36 +17,12 @@ mesh.mesh @mesh4(shape = 3)
 // CHECK: mesh.mesh @mesh5(shape = ?)
 mesh.mesh @mesh5(shape = ?)
 
-// CHECK-LABEL: func @mesh_shard_encoding_fully_replicated
-func.func @mesh_shard_encoding_fully_replicated(
-    // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>>
-    %arg0 : tensor<4x8xf32, #m...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79606


More information about the Mlir-commits mailing list