[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)
Boian Petkantchin
llvmlistbot at llvm.org
Thu Dec 14 14:25:51 PST 2023
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/74905
>From 29908d9d2770c59978f90d8504e3a0df22e2fc0e Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 09:47:42 -0800
Subject: [PATCH 1/2] [mlir][mesh] Add verification and canonicalization for
some collectives
Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.
The canonicalizations only remove trivial collectives with empty mesh_axes
attributes.
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 7 +
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 176 ++++++++-
mlir/test/Dialect/Mesh/canonicalization.mlir | 79 ++++
mlir/test/Dialect/Mesh/invalid.mlir | 395 +++++++++++++++++++
mlir/test/Dialect/Mesh/ops.mlir | 245 ++++++++++++
5 files changed, 881 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6cdba949b1721..fa6f9dbb79872f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,6 +673,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
}];
+ let hasCanonicalizer = 1;
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 1ba95f21ec7f3d..683b9adcd380a6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
} // namespace
+static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
+ ArrayRef<int64_t> device,
+ Operation::operand_range deviceDynamic,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ if (device.size() != meshAxes.size()) {
+ return emitError(loc) << "In-group device \"" << deviceName
+ << "\" has unexpected multi-index size "
+ << device.size() << ". Expected " << meshAxes.size()
+ << ".";
+ }
+
+ for (size_t i = 0; i < device.size(); ++i) {
+ if (!ShapedType::isDynamic(device[i]) &&
+ !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
+ meshShape[meshAxes[i]] <= device[i]) {
+ return emitError(loc)
+ << "Out of bounds coordinate " << i << " for in-group device \""
+ << deviceName << "\"."
+ << " Got " << device[i] << ", but expected value in the range [0, "
+ << (meshShape[meshAxes[i]] - 1) << "].";
+ }
+ }
+ return success();
+}
+
static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
return success();
}
-static LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
return success();
}
-static LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyScatterOperandAndResultShape(
Value operand, Value result, int64_t scatterAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
- return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
- gatherAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ return verifyGatherOperandAndResultShape(getOperand(), getResult(),
+ gatherAxis, getMeshAxes(),
+ mesh.value().canonicalDimSizes());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,8 +537,22 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -519,17 +560,48 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto gatherAxis = getGatherAxis().getSExtValue();
+ return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+ getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
//===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
//===----------------------------------------------------------------------===//
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (getSource() && failed(verifyInGroupDevice(
+ getLoc(), getSourceAttrName(), getSource().value(),
+ getSourceDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+ return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -537,8 +609,22 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ return success();
+}
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
- return verifyReduceScatterOperandAndResultShape(
+ return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
}
@@ -567,8 +653,25 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto scatterAxis = getScatterAxis().getSExtValue();
+ return verifyScatterOperandAndResultShape(getInput(), getResult(),
+ scatterAxis, getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -576,8 +679,22 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+ getDestination(), getDestinationDynamic(),
+ getMeshAxes(), meshShape))) {
+ return failure();
+ }
+ return success();
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -585,8 +702,25 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+
+ auto meshAxes = getMeshAxes();
+ auto shiftAxis = getShiftAxis().getZExtValue();
+ if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
+ return emitError() << "Invalid shift axis " << shiftAxis
+ << ". It must be one of the grouping mesh axes.";
+ }
+
+ return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ // TODO: remove op when offset is 0 or if it is a rotate with an
+ // offset % shift_axis_mesh_dim_size == 0.
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index baee9faa645c93..0a00ab41268d01 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,58 @@ func.func @all_gather_empty_mesh_axes(
return %0 : tensor<4xf32>
}
+// CHECK-LABEL: func @broadcast_empty_mesh_axes
+func.func @broadcast_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.broadcast
+ %0 = mesh.broadcast %arg0 on @mesh0
+ mesh_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_mesh_axes
+func.func @gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.gather
+ %0 = mesh.gather %arg0 on @mesh0
+ mesh_axes = []
+ gather_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_mesh_axes
+func.func @receive_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.recv
+ %0 = mesh.recv %arg0 on @mesh0
+ mesh_axes = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_mesh_axes
+func.func @reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce
+ %0 = mesh.reduce %arg0 on @mesh0
+ mesh_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
@@ -99,3 +151,30 @@ func.func @reduce_scatter_default_reduction(
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
+
+// CHECK-LABEL: func @scatter_empty_mesh_axes
+func.func @scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.scatter
+ %0 = mesh.scatter %arg0 on @mesh0
+ mesh_axes = []
+ scatter_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_mesh_axes
+func.func @send_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.send
+ %0 = mesh.send %arg0 on @mesh0
+ mesh_axes = []
+ destination = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index a26e3950186e95..03994f8f011e1f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -298,6 +298,221 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [3]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_different_input_and_result_type(
+ %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ // expected-error @+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [2]
+ : (tensor<2xi8>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_wrong_return_element_type(
+ %arg0 : tensor<1xf32>) -> tensor<1xi8> {
+ // expected-error @+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ : (tensor<1xf32>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_non_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error @+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ : (tensor<3x4xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+
+func.func @gather_invalid_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error @+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+ : (tensor<3x4xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+ // expected-error @+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+ : (tensor<?xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error @+1 {{Gather axis 1 is out of bounds [0, 1).}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+ : (tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_negative_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error @+1 {{Gather axis -1 is out of bounds [0, 1).}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+ : (tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @gather_root_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<6xi8> {
+ // expected-error @+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ root = [3]
+ : (tensor<2xi8>) -> tensor<6xi8>
+ return %0 : tensor<6xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @gather_root_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0
+ root = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @receive_source_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{Out of bounds coordinate 0 for in-group device "source". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ source = [3]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @receive_source_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{In-group device "source" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ source = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @receive_different_input_and_result_type(
+ %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ // expected-error @+1 {{'mesh.recv' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0]
+ source = [2]
+ : (tensor<2xi8>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @reduce_root_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ root = [3]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @reduce_root_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ root = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @reduce_different_input_and_result_shape(
+ %arg0 : tensor<2xi8>) -> tensor<3xi16> {
+ // expected-error @+1 {{'mesh.reduce' op failed to verify that all of {input, result} have same shape}}
+ %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0]
+ root = [2]
+ : (tensor<2xi8>) -> tensor<3xi16>
+ return %0 : tensor<3xi16>
+}
+
+// -----
+
mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
func.func @reduce_scatter_duplicate_mesh_axis(
@@ -343,3 +558,183 @@ func.func @reduce_scatter_invalid_operand_static_dimension_size(
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_duplicate_mesh_axis(
+ %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{Mesh axes contains duplicate elements.}}
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+ scatter_axis = 0 root = [0, 0]
+ : (tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_invalid_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<2xf32> {
+ // expected-error @+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+ %0 = mesh.scatter %arg0 on @mesh0
+ scatter_axis = 0 root = []
+ : (tensor<?xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_invalid_static_dimension_size(
+ %arg0 : tensor<3xf32>) -> tensor<2xf32> {
+ // expected-error @+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ scatter_axis = 0 root = [1]
+ : (tensor<3xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 3)
+
+func.func @scatter_invalid_operand_static_dimension_size(
+ %arg0 : tensor<4xf32>) -> tensor<?xf32> {
+ // expected-error @+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ scatter_axis = 0 root = [1]
+ : (tensor<4xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @scatter_root_dimension_out_of_bounds(
+ %arg0 : tensor<3xi8>) -> tensor<1xi8> {
+ // expected-error @+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ scatter_axis = 0 root = [3]
+ : (tensor<3xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @scatter_root_wrong_number_dimensions(
+ %arg0 : tensor<3xi8>) -> tensor<1xi8> {
+ // expected-error @+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
+ scatter_axis = 0 root = [2, 2]
+ : (tensor<3xi8>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @send_destination_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{Out of bounds coordinate 0 for in-group device "destination". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ destination = [3]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @send_destination_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error @+1 {{In-group device "destination" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ destination = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @send_different_input_and_result_type(
+ %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ // expected-error @+1 {{'mesh.send' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0]
+ destination = [2]
+ : (tensor<2xi8>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
+// -----
+
+func.func @shift_invalid_mesh_symbol(
+ %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+ // expected-error @+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.shift %arg0 on @this_mesh_symbol_does_not_exist
+ shift_axis = 0 offset = -2
+ : tensor<4xi8> -> tensor<4xi8>
+ return %0 : tensor<4xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_invalid_mesh_axis(
+ %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+ // expected-error @+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [2]
+ shift_axis = 2 offset = -2
+ : tensor<4xi8> -> tensor<4xi8>
+ return %0 : tensor<4xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_duplicate_mesh_axis(
+ %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+ // expected-error @+1 {{Mesh axes contains duplicate elements.}}
+ %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 1, 0]
+ shift_axis = 0 offset = -2
+ : tensor<4xi8> -> tensor<4xi8>
+ return %0 : tensor<4xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_invalid_tensor_dimension_size(
+ %arg0 : tensor<4xi8>) -> tensor<5xi8> {
+ // expected-error @+1 {{'mesh.shift' op requires the same shape for all operands and results}}
+ %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+ shift_axis = 0 offset = 2
+ : tensor<4xi8> -> tensor<5xi8>
+ return %0 : tensor<5xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @shift_invalid_shift_axis(
+ %arg0 : tensor<4xi8>) -> tensor<4xi8> {
+ // expected-error @+1 {{Invalid shift axis 1. It must be one of the grouping mesh axes.}}
+ %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0]
+ shift_axis = 1 offset = 2
+ : tensor<4xi8> -> tensor<4xi8>
+ return %0 : tensor<4xi8>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 78ce276a7b33a3..8f8e309d18f156 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -228,6 +228,159 @@ func.func @all_to_all_non_divisible_split_axis_size(
return %0 : tensor<?x12xi8>
}
+// CHECK-LABEL: func @broadcast_static_root
+func.func @broadcast_static_root(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
+ // CHECK-NEXT: mesh.broadcast %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: root = [0, 1]
+ // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+ root = [0, 1]
+ : (tensor<3x6xi8>) -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @broadcast_dynamic_root
+func.func @broadcast_dynamic_root(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: index
+ %arg1 : index
+ ) -> tensor<3x6xi8> {
+ // CHECK-NEXT: mesh.broadcast %[[ARG0]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: root = [1, %[[ARG1]]]
+ // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2]
+ root = [1, %arg1]
+ : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
+ return %0 : tensor<3x6xi8>
+}
+
+// CHECK-LABEL: func @gather_static_root
+func.func @gather_static_root(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> {
+ // CHECK-NEXT: mesh.gather %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: gather_axis = 0
+ // CHECK-SAME: root = [0, 1]
+ // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+ gather_axis = 0
+ root = [0, 1]
+ : (tensor<3x6xi8>) -> tensor<24x6xi8>
+ return %0 : tensor<24x6xi8>
+}
+
+// CHECK-LABEL: func @gather_dynamic_root
+func.func @gather_dynamic_root(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
+ %arg0 : tensor<3x6xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: index
+ %arg1 : index
+ ) -> tensor<24x6xi8> {
+ // CHECK-NEXT: mesh.gather %[[ARG0]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: gather_axis = 0
+ // CHECK-SAME: root = [1, %[[ARG1]]]
+ // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2]
+ gather_axis = 0
+ root = [1, %arg1]
+ : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
+ return %0 : tensor<24x6xi8>
+}
+
+// CHECK-LABEL: func @receive_static_source
+func.func @receive_static_source(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.recv %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: source = [0, 1]
+ // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
+ %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ source = [0, 1]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @receive_dynamic_source
+func.func @receive_dynamic_source(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: index
+ %arg1 : index
+ ) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.recv %[[ARG0]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: source = [1, %[[ARG1]]]
+ // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
+ %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ source = [1, %arg1]
+ : (tensor<2xi8>, index) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @receive_no_source
+func.func @receive_no_source(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.recv %[[ARG]]
+ // CHECK-NOT: source
+ %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @reduce_static_root
+func.func @reduce_static_root(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.reduce %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: root = [0, 1]
+ // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
+ %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ root = [0, 1]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @reduce_dynamic_root
+func.func @reduce_dynamic_root(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: index
+ %arg1 : index
+ ) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.reduce %[[ARG0]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: root = [1, %[[ARG1]]]
+ // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
+ %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ root = [1, %arg1]
+ : (tensor<2xi8>, index) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @reduce_different_return_element_type
+func.func @reduce_different_return_element_type(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ // CHECK-NEXT: mesh.reduce %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: root = [0, 1]
+ // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
+ %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2]
+ root = [0, 1]
+ : (tensor<2xi8>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
// CHECK-LABEL: func @reduce_scatter_static_dimensions
func.func @reduce_scatter_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
@@ -252,3 +405,95 @@ func.func @reduce_scatter_dynamic_dimensions(
: tensor<?xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
}
+
+// CHECK-LABEL: func @scatter_static_dimensions
+func.func @scatter_static_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
+ // CHECK-NEXT: mesh.scatter %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [2]
+ // CHECK-SAME: scatter_axis = 1 root = [1]
+ // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2]
+ scatter_axis = 1 root = [1]
+ : (tensor<3x4xf32>) -> tensor<3x1xf32>
+ return %0 : tensor<3x1xf32>
+}
+
+// CHECK-LABEL: func @scatter_dynamic_dimensions
+func.func @scatter_dynamic_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+ %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK-NEXT: mesh.scatter %[[ARG]]
+ // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
+ // CHECK-SAME: scatter_axis = 0 root = [1, 2]
+ // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
+ %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1]
+ scatter_axis = 0 root = [1, 2]
+ : (tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @scatter_dynamic_root
+func.func @scatter_dynamic_root(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8>
+ %arg0 : tensor<8xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: index
+ %arg1 : index
+ ) -> tensor<1xi8> {
+ // CHECK-NEXT: mesh.scatter %[[ARG0]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: scatter_axis = 0
+ // CHECK-SAME: root = [1, %[[ARG1]]]
+ // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
+ %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2]
+ scatter_axis = 0
+ root = [1, %arg1]
+ : (tensor<8xi8>, index) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// CHECK-LABEL: func @send_static_destination
+func.func @send_static_destination(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.send %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: destination = [0, 1]
+ // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
+ %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+ destination = [0, 1]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @send_dynamic_destination
+func.func @send_dynamic_destination(
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>,
+ // CHECK-SAME: %[[ARG1:.*]]: index
+ %arg1 : index
+ ) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.send %[[ARG0]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: destination = [1, %[[ARG1]]]
+ // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
+ %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2]
+ destination = [1, %arg1]
+ : (tensor<2xi8>, index) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// CHECK-LABEL: func @shift
+func.func @shift(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // CHECK-NEXT: mesh.shift %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
+ // CHECK-SAME: shift_axis = 2 offset = -2 rotate
+ // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
+ %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2]
+ shift_axis = 2 offset = -2 rotate
+ : tensor<2xi8> -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
From 66d77515db00a3ad70569dc67edd3aec637d3313 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 14 Dec 2023 10:59:30 -0800
Subject: [PATCH 2/2] Fix typo
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 2 +-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index fa6f9dbb79872f..784f3eb97763ad 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -680,7 +680,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
SameOperandsAndResultElementType,
SameOperandsAndResultShape
]> {
- let summary = "Sift over a device mesh.";
+ let summary = "Shift over a device mesh.";
let description = [{
Within each device group shift along mesh axis `shift_axis` by an offset
`offset`.
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 683b9adcd380a6..d27675564c6464 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -720,7 +720,7 @@ LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
// TODO: remove op when offset is 0 or if it is a rotate with and
- // offset % sift_axis_mesh_dim_size == 0.
+ // offset % shift_axis_mesh_dim_size == 0.
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list