[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for some collectives (PR #74905)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 8 16:59:14 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.
The canonicalizations only remove trivial collectives with empty mesh_axes attributes.
---
Patch is 38.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74905.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+7)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+155-21)
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+79)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+395)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+245)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6cdba949b1721..fa6f9dbb79872f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
(`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
`root` `=` custom<DynamicIndexList>($root_dynamic, $root)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,6 +673,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
`destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
attr-dict `:` functional-type(operands, results)
}];
+ let hasCanonicalizer = 1;
}
def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
(`rotate` $rotate^)?
attr-dict `:` type($input) `->` type($result)
}];
+ let hasCanonicalizer = 1;
}
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 1ba95f21ec7f3d..683b9adcd380a6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
} // namespace
+static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
+ ArrayRef<int64_t> device,
+ Operation::operand_range deviceDynamic,
+ ArrayRef<MeshAxis> meshAxes,
+ ArrayRef<int64_t> meshShape) {
+ if (device.size() != meshAxes.size()) {
+ return emitError(loc) << "In-group device \"" << deviceName
+ << "\" has unexpected multi-index size "
+ << device.size() << ". Expected " << meshAxes.size()
+ << ".";
+ }
+
+ for (size_t i = 0; i < device.size(); ++i) {
+ if (!ShapedType::isDynamic(device[i]) &&
+ !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
+ meshShape[meshAxes[i]] <= device[i]) {
+ return emitError(loc)
+ << "Out of bounds coordinate " << i << " for in-group device \""
+ << deviceName << "\"."
+ << " Got " << device[i] << ", but expected value in the range [0, "
+ << (meshShape[meshAxes[i]] - 1) << "].";
+ }
+ }
+ return success();
+}
+
static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
return success();
}
-static LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
return success();
}
-static LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyScatterOperandAndResultShape(
Value operand, Value result, int64_t scatterAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
auto gatherAxis = getGatherAxis().getSExtValue();
- return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
- gatherAxis, getMeshAxes(),
- mesh.value().canonicalDimSizes());
+ return verifyGatherOperandAndResultShape(getOperand(), getResult(),
+ gatherAxis, getMeshAxes(),
+ mesh.value().canonicalDimSizes());
}
void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,8 +537,22 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
LogicalResult
BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -519,17 +560,48 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto gatherAxis = getGatherAxis().getSExtValue();
+ return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+ getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
}
//===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
//===----------------------------------------------------------------------===//
LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (getSource() && failed(verifyInGroupDevice(
+ getLoc(), getSourceAttrName(), getSource().value(),
+ getSourceDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+ return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -537,8 +609,22 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ return success();
+}
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
- return verifyReduceScatterOperandAndResultShape(
+ return verifyScatterOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().canonicalDimSizes());
}
@@ -567,8 +653,25 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
//===----------------------------------------------------------------------===//
LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+ getRootDynamic(), getMeshAxes(), meshShape))) {
+ return failure();
+ }
+
+ auto scatterAxis = getScatterAxis().getSExtValue();
+ return verifyScatterOperandAndResultShape(getInput(), getResult(),
+ scatterAxis, getMeshAxes(),
+ mesh.value().canonicalDimSizes());
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -576,8 +679,22 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ auto meshShape = mesh.value().canonicalDimSizes();
+ if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+ getDestination(), getDestinationDynamic(),
+ getMeshAxes(), meshShape))) {
+ return failure();
+ }
+ return success();
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
}
//===----------------------------------------------------------------------===//
@@ -585,8 +702,25 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- // TODO
- return failure();
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+
+ auto meshAxes = getMeshAxes();
+ auto shiftAxis = getShiftAxis().getZExtValue();
+ if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
+ return emitError() << "Invalid shift axis " << shiftAxis
+ << ". It must be one of the grouping mesh axes.";
+ }
+
+ return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ // TODO: remove op when offset is 0 or if it is a rotate with an
+ // offset % shift_axis_mesh_dim_size == 0.
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index baee9faa645c93..0a00ab41268d01 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,58 @@ func.func @all_gather_empty_mesh_axes(
return %0 : tensor<4xf32>
}
+// CHECK-LABEL: func @broadcast_empty_mesh_axes
+func.func @broadcast_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.broadcast
+ %0 = mesh.broadcast %arg0 on @mesh0
+ mesh_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_mesh_axes
+func.func @gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.gather
+ %0 = mesh.gather %arg0 on @mesh0
+ mesh_axes = []
+ gather_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_mesh_axes
+func.func @receive_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.recv
+ %0 = mesh.recv %arg0 on @mesh0
+ mesh_axes = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_mesh_axes
+func.func @reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce
+ %0 = mesh.reduce %arg0 on @mesh0
+ mesh_axes = []
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
func.func @reduce_scatter_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
@@ -99,3 +151,30 @@ func.func @reduce_scatter_default_reduction(
: tensor<4xf32> -> tensor<2xf64>
return %0 : tensor<2xf64>
}
+
+// CHECK-LABEL: func @scatter_empty_mesh_axes
+func.func @scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.scatter
+ %0 = mesh.scatter %arg0 on @mesh0
+ mesh_axes = []
+ scatter_axis = 0
+ root = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_mesh_axes
+func.func @send_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+ %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.send
+ %0 = mesh.send %arg0 on @mesh0
+ mesh_axes = []
+ destination = []
+ : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+ return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index a26e3950186e95..03994f8f011e1f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -298,6 +298,221 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
// -----
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_dimension_out_of_bounds(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [3]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_wrong_number_dimensions(
+ %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+ // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [2, 2]
+ : (tensor<2xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_different_input_and_result_type(
+ %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+ // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+ root = [2]
+ : (tensor<2xi8>) -> tensor<2xi16>
+ return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_wrong_return_element_type(
+ %arg0 : tensor<1xf32>) -> tensor<1xi8> {
+ // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ : (tensor<1xf32>) -> tensor<1xi8>
+ return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_non_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+ : (tensor<3x4xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+
+func.func @gather_invalid_gather_axis_dimension_size(
+ %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+ : (tensor<3x4xf32>) -> tensor<3x5xf32>
+ return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis_dynamic_dimension(
+ %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+ %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+ : (tensor<?xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+ : (tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_negative_gather_axis(
+ %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+ // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+ %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+ : (tensor<3xf32>) -> tensor<3xf32>
+ return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh....
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/74905
More information about the Mlir-commits
mailing list