[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)

Boian Petkantchin llvmlistbot at llvm.org
Fri Dec 8 17:00:47 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/74905

>From 6da9e6ae6f7288594dde5816422789ee9d601e6f Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 09:47:42 -0800
Subject: [PATCH] [mlir][mesh] Add verification and canonicalization for some
 collectives

Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty mesh_axes
attributes.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td |   7 +
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp         | 176 ++++++++-
 mlir/test/Dialect/Mesh/canonicalization.mlir |  79 ++++
 mlir/test/Dialect/Mesh/invalid.mlir          | 395 +++++++++++++++++++
 mlir/test/Dialect/Mesh/ops.mlir              | 245 ++++++++++++
 5 files changed, 881 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6cdba949b1721..fa6f9dbb79872f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
     (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,6 +673,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
     `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
     (`rotate` $rotate^)?
     attr-dict `:` type($input) `->` type($result)
   }];
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 1ba95f21ec7f3d..683b9adcd380a6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
 
 } // namespace
 
+static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
+                                         ArrayRef<int64_t> device,
+                                         Operation::operand_range deviceDynamic,
+                                         ArrayRef<MeshAxis> meshAxes,
+                                         ArrayRef<int64_t> meshShape) {
+  if (device.size() != meshAxes.size()) {
+    return emitError(loc) << "In-group device \"" << deviceName
+                          << "\" has unexpected multi-index size "
+                          << device.size() << ". Expected " << meshAxes.size()
+                          << ".";
+  }
+
+  for (size_t i = 0; i < device.size(); ++i) {
+    if (!ShapedType::isDynamic(device[i]) &&
+        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
+        meshShape[meshAxes[i]] <= device[i]) {
+      return emitError(loc)
+             << "Out of bounds coordinate " << i << " for in-group device \""
+             << deviceName << "\"."
+             << " Got " << device[i] << ", but expected value in the range [0, "
+             << (meshShape[meshAxes[i]] - 1) << "].";
+    }
+  }
+  return success();
+}
+
 static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                                     SymbolTableCollection &symbolTable) {
   mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
   return success();
 }
 
-static LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   return success();
 }
 
-static LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyScatterOperandAndResultShape(
     Value operand, Value result, int64_t scatterAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   auto gatherAxis = getGatherAxis().getSExtValue();
-  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
-                                              gatherAxis, getMeshAxes(),
-                                              mesh.value().canonicalDimSizes());
+  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
+                                           gatherAxis, getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,8 +537,22 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 LogicalResult
 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -519,17 +560,48 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto gatherAxis = getGatherAxis().getSExtValue();
+  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+                                           getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
 //===----------------------------------------------------------------------===//
 
 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (getSource() && failed(verifyInGroupDevice(
+                         getLoc(), getSourceAttrName(), getSource().value(),
+                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -537,8 +609,22 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
 
-  return verifyReduceScatterOperandAndResultShape(
+  return verifyScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
       mesh.value().canonicalDimSizes());
 }
@@ -567,8 +653,25 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto scatterAxis = getScatterAxis().getSExtValue();
+  return verifyScatterOperandAndResultShape(getInput(), getResult(),
+                                            scatterAxis, getMeshAxes(),
+                                            mesh.value().canonicalDimSizes());
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,8 +679,22 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+                                 getDestination(), getDestinationDynamic(),
+                                 getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -585,8 +702,25 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+
+  auto meshAxes = getMeshAxes();
+  auto shiftAxis = getShiftAxis().getZExtValue();
+  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
+    return emitError() << "Invalid shift axis " << shiftAxis
+                       << ". It must be one of the grouping mesh axes.";
+  }
+
+  return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                          MLIRContext *context) {
+  // TODO: remove op when offset is 0 or if it is a rotate with an
+  // offset % shift_axis_mesh_dim_size == 0.
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index baee9faa645c93..0a00ab41268d01 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,58 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @broadcast_empty_mesh_axes
+func.func @broadcast_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.broadcast
+  %0 = mesh.broadcast %arg0 on @mesh0
+    mesh_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_mesh_axes
+func.func @gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.gather
+  %0 = mesh.gather %arg0 on @mesh0
+    mesh_axes = []
+    gather_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_mesh_axes
+func.func @receive_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.recv
+  %0 = mesh.recv %arg0 on @mesh0
+    mesh_axes = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_mesh_axes
+func.func @reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce
+  %0 = mesh.reduce %arg0 on @mesh0
+    mesh_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
 // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
 func.func @reduce_scatter_empty_mesh_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
@@ -99,3 +151,30 @@ func.func @reduce_scatter_default_reduction(
     : tensor<4xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
+
+// CHECK-LABEL: func @scatter_empty_mesh_axes
+func.func @scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.scatter
+  %0 = mesh.scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_mesh_axes
+func.func @send_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.send
+  %0 = mesh.send %arg0 on @mesh0
+    mesh_axes = []
+    destination = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index a26e3950186e95..03994f8f011e1f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -298,6 +298,221 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [3]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_different_input_and_result_type(
+    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+  // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [2]
+    : (tensor<2xi8>) -> tensor<2xi16>
+  return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_wrong_return_element_type(
+    %arg0 : tensor<1xf32>) -> tensor<1xi8> {
+  // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+    : (tensor<1xf32>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_non_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+    : (tensor<3x4xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+
+func.func @gather_invalid_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+    : (tensor<3x4xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+    : (tensor<?xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+    : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_negative_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+    : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @gather_root_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<6xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+    root = [3]
+    : (tensor<2xi8>) -> tensor<6xi8>
+  return %0 : tensor<6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @gather_root_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+    root = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @receive_source_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+    source = [3]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @receive_source_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+    source = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @receive_different_input_and_result_type(
+    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+  // expected-error@+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+    source = [2]
+    : (tensor<2xi8>) -> tensor<2xi16>
+  return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @reduce_root_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+    root = [3]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @reduce_root_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+    root = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @reduce_different_input_and_result_shape(
+    %arg0 : tensor<2xi8>) -> tensor<3xi16> {
+  // expected-error@+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
+  %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+    root = [2]
+    : (tensor<2xi8>) -> tensor<3xi16>
+  return %0 : tensor<3xi16>
+}
+
+// -----
+
 mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
 
 func.func @reduce_scatter_duplicate_mesh_axis(
@@ -343,3 +558,183 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size(
     : tensor<4xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+    scatter_axis = 0 root = [0, 0]
+    : (tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.scatter %arg0 on @mesh0
+    scatter_axis = 0 root = []
+    : (tensor<?xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0 root = [1]
+    : (tensor<3xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
+  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0 root = [1]
+    : (tensor<4xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @scatter_root_dimension_out_of_bounds(
+    %arg0 : tensor<3xi8>) -> tensor<1xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0 root = [3]
+    : (tensor<3xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @scatter_root_wrong_number_dimensions(
+    %arg0 : tensor<3xi8>) -> tensor<1xi8> {
+  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0 root = [2, 2]
+    : (tensor<3xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @send_destination_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+    destination = [3]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @send_destination_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+    destination = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @send_different_input_and_result_type(
+    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+  // expected-error@+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+    destination = [2]
+    : (tensor<2xi8>) -> tensor<2xi16>
+  return %0 : tensor<2xi16>
+}
+
+// -----
+
+func.func @shift_invalid_mesh_symbol(
+    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+  // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+  %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
+    shift_axis = 0 offset = -2
+    : tensor<4xi8> -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_invalid_mesh_axis(
+    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+  // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
+        shift_axis = 2 offset = -2
+    : tensor<4xi8> -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_duplicate_mesh_axis(
+    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
+    shift_axis = 0 offset = -2
+    : tensor<4xi8> -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_invalid_tensor_dimension_size(
+    %arg0 : tensor<4xi8>) -> tensor<5xi8> {
+  // expected-error@+1 {{'mesh.shift' op requires the same shape for all operands and results}}
+  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+    shift_axis = 0 offset = 2
+    : tensor<4xi8> -> tensor<5xi8>
+  return %0 : tensor<5xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_invalid_shift_axis(
+    %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+  // expected-error@+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
+  %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+    shift_axis = 1 offset = 2
+    : tensor<4xi8> -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 78ce276a7b33a3..8f8e309d18f156 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -228,6 +228,159 @@ func.func @all_to_all_non_divisible_split_axis_size(
   return %0 : tensor<?x12xi8>
 }
 
+// CHECK-LABEL: func @broadcast_static_root
+func.func @broadcast_static_root(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %input: tensor<3x6xi8>) -> tensor<3x6xi8> {
+  // CHECK-NEXT: mesh.broadcast %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: root = [0, 1]
+  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
+  %broadcasted = mesh.broadcast %input
+      on @mesh0 mesh_axes = [0, 2] root = [0, 1]
+      : (tensor<3x6xi8>) -> tensor<3x6xi8>
+  return %broadcasted : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @broadcast_dynamic_root
+func.func @broadcast_dynamic_root(
+    // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
+    %input: tensor<3x6xi8>,
+    // CHECK-SAME: %[[ARG1:.*]]: index
+    %root_idx: index
+    ) -> tensor<3x6xi8> {
+  // CHECK-NEXT: mesh.broadcast %[[ARG0]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: root = [1, %[[ARG1]]]
+  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
+  %broadcasted = mesh.broadcast %input
+      on @mesh0 mesh_axes = [0, 2] root = [1, %root_idx]
+      : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
+  return %broadcasted : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @gather_static_root
+func.func @gather_static_root(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+    %input: tensor<3x6xi8>) -> tensor<24x6xi8> {
+  // CHECK-NEXT: mesh.gather %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: gather_axis = 0
+  // CHECK-SAME: root = [0, 1]
+  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
+  %gathered = mesh.gather %input
+      on @mesh0 mesh_axes = [0, 2]
+      gather_axis = 0 root = [0, 1]
+      : (tensor<3x6xi8>) -> tensor<24x6xi8>
+  return %gathered : tensor<24x6xi8>
+}
+
+// CHECK-LABEL: func @gather_dynamic_root
+func.func @gather_dynamic_root(
+    // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
+    %input: tensor<3x6xi8>,
+    // CHECK-SAME: %[[ARG1:.*]]: index
+    %root_idx: index
+    ) -> tensor<24x6xi8> {
+  // CHECK-NEXT: mesh.gather %[[ARG0]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: gather_axis = 0
+  // CHECK-SAME: root = [1, %[[ARG1]]]
+  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
+  %gathered = mesh.gather %input
+      on @mesh0 mesh_axes = [0, 2]
+      gather_axis = 0 root = [1, %root_idx]
+      : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
+  return %gathered : tensor<24x6xi8>
+}
+
+// CHECK-LABEL: func @receive_static_source
+func.func @receive_static_source(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.recv %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: source = [0, 1]
+  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
+  %received = mesh.recv %input
+      on @mesh0 mesh_axes = [0, 2] source = [0, 1]
+      : (tensor<2xi8>) -> tensor<2xi8>
+  return %received : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @receive_dynamic_source
+func.func @receive_dynamic_source(
+    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>,
+    // CHECK-SAME: %[[ARG1:.*]]: index
+    %src_idx: index
+    ) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.recv %[[ARG0]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: source = [1, %[[ARG1]]]
+  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
+  %received = mesh.recv %input
+      on @mesh0 mesh_axes = [0, 2] source = [1, %src_idx]
+      : (tensor<2xi8>, index) -> tensor<2xi8>
+  return %received : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @receive_no_source
+func.func @receive_no_source(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.recv %[[ARG]]
+  // CHECK-NOT: source
+  %received = mesh.recv %input on @mesh0 mesh_axes = [0, 2]
+      : (tensor<2xi8>) -> tensor<2xi8>
+  return %received : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @reduce_static_root
+func.func @reduce_static_root(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.reduce %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: root = [0, 1]
+  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
+  %reduced = mesh.reduce %input
+      on @mesh0 mesh_axes = [0, 2] root = [0, 1]
+      : (tensor<2xi8>) -> tensor<2xi8>
+  return %reduced : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @reduce_dynamic_root
+func.func @reduce_dynamic_root(
+    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>,
+    // CHECK-SAME: %[[ARG1:.*]]: index
+    %root_idx: index
+    ) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.reduce %[[ARG0]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: root = [1, %[[ARG1]]]
+  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
+  %reduced = mesh.reduce %input
+      on @mesh0 mesh_axes = [0, 2] root = [1, %root_idx]
+      : (tensor<2xi8>, index) -> tensor<2xi8>
+  return %reduced : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @reduce_different_return_element_type
+func.func @reduce_different_return_element_type(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>) -> tensor<2xi16> {
+  // CHECK-NEXT: mesh.reduce %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: root = [0, 1]
+  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
+  %widened = mesh.reduce %input
+      on @mesh0 mesh_axes = [0, 2] root = [0, 1]
+      : (tensor<2xi8>) -> tensor<2xi16>
+  return %widened : tensor<2xi16>
+}
+
 // CHECK-LABEL: func @reduce_scatter_static_dimensions
 func.func @reduce_scatter_static_dimensions(
     // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
@@ -252,3 +405,95 @@ func.func @reduce_scatter_dynamic_dimensions(
     : tensor<?xf32> -> tensor<?xf64>
   return %0 : tensor<?xf64>
 }
+
+// CHECK-LABEL: func @scatter_static_dimensions
+func.func @scatter_static_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+    %input: tensor<3x4xf32>) -> tensor<3x1xf32> {
+  // CHECK-NEXT: mesh.scatter %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [2]
+  // CHECK-SAME: scatter_axis = 1 root = [1]
+  // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
+  %scattered = mesh.scatter %input
+      on @mesh0 mesh_axes = [2] scatter_axis = 1 root = [1]
+      : (tensor<3x4xf32>) -> tensor<3x1xf32>
+  return %scattered : tensor<3x1xf32>
+}
+
+// CHECK-LABEL: func @scatter_dynamic_dimensions
+func.func @scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+    %input: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: mesh.scatter %[[ARG]]
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
+  // CHECK-SAME: scatter_axis = 0 root = [1, 2]
+  // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
+  %scattered = mesh.scatter %input
+      on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 root = [1, 2]
+      : (tensor<?xf32>) -> tensor<?xf32>
+  return %scattered : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @scatter_dynamic_root
+func.func @scatter_dynamic_root(
+    // CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8>
+    %input: tensor<8xi8>,
+    // CHECK-SAME: %[[ARG1:.*]]: index
+    %root_idx: index
+    ) -> tensor<1xi8> {
+  // CHECK-NEXT: mesh.scatter %[[ARG0]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: scatter_axis = 0
+  // CHECK-SAME: root = [1, %[[ARG1]]]
+  // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
+  %scattered = mesh.scatter %input
+      on @mesh0 mesh_axes = [0, 2]
+      scatter_axis = 0 root = [1, %root_idx]
+      : (tensor<8xi8>, index) -> tensor<1xi8>
+  return %scattered : tensor<1xi8>
+}
+
+// CHECK-LABEL: func @send_static_destination
+func.func @send_static_destination(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.send %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: destination = [0, 1]
+  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
+  %sent = mesh.send %input
+      on @mesh0 mesh_axes = [0, 2] destination = [0, 1]
+      : (tensor<2xi8>) -> tensor<2xi8>
+  return %sent : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @send_dynamic_destination
+func.func @send_dynamic_destination(
+    // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>,
+    // CHECK-SAME: %[[ARG1:.*]]: index
+    %dest_idx: index
+    ) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.send %[[ARG0]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: destination = [1, %[[ARG1]]]
+  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
+  %sent = mesh.send %input
+      on @mesh0 mesh_axes = [0, 2] destination = [1, %dest_idx]
+      : (tensor<2xi8>, index) -> tensor<2xi8>
+  return %sent : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @shift
+func.func @shift(
+    // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+    %input: tensor<2xi8>) -> tensor<2xi8> {
+  // CHECK-NEXT: mesh.shift %[[ARG]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+  // CHECK-SAME: shift_axis = 2 offset = -2 rotate
+  // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
+  %shifted = mesh.shift %input
+      on @mesh0 mesh_axes = [0, 2] shift_axis = 2 offset = -2 rotate
+      : tensor<2xi8> -> tensor<2xi8>
+  return %shifted : tensor<2xi8>
+}



More information about the Mlir-commits mailing list