[Mlir-commits] [mlir] [mlir][mesh] Add verification and canonicalization for the some collectives (PR #74905)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 8 16:59:14 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Boian Petkantchin (sogartar)

<details>
<summary>Changes</summary>

Add verification and canonicalization for
broadcast, gather, recv, reduce, scatter, send and shift.

The canonicalizations only remove trivial collectives with empty mesh_axes attributes.

---

Patch is 38.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74905.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+7) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+155-21) 
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+79) 
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+395) 
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+245) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index e6cdba949b1721..fa6f9dbb79872f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -392,6 +392,7 @@ def Mesh_BroadcastOp : Mesh_CollectiveCommunicationOpBase<"broadcast", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
@@ -454,6 +455,7 @@ def Mesh_GatherOp : Mesh_CollectiveCommunicationOpBase<"gather", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
@@ -477,6 +479,7 @@ def Mesh_RecvOp : Mesh_CollectiveCommunicationOpBase<"recv", [
     (`source` `=` custom<DynamicIndexList>($source_dynamic, $source)^)?
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
@@ -517,6 +520,7 @@ def Mesh_ReduceOp : Mesh_CollectiveCommunicationOpBase<"reduce", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
@@ -645,6 +649,7 @@ def Mesh_ScatterOp : Mesh_CollectiveCommunicationOpBase<"scatter", [
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
@@ -668,6 +673,7 @@ def Mesh_SendOp : Mesh_CollectiveCommunicationOpBase<"send", [
     `destination` `=` custom<DynamicIndexList>($destination_dynamic, $destination)
     attr-dict `:` functional-type(operands, results)
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
@@ -728,6 +734,7 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
     (`rotate` $rotate^)?
     attr-dict `:` type($input) `->` type($result)
   }];
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 1ba95f21ec7f3d..683b9adcd380a6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
@@ -231,6 +232,32 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> {
 
 } // namespace
 
+static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
+                                         ArrayRef<int64_t> device,
+                                         Operation::operand_range deviceDynamic,
+                                         ArrayRef<MeshAxis> meshAxes,
+                                         ArrayRef<int64_t> meshShape) {
+  if (device.size() != meshAxes.size()) {
+    return emitError(loc) << "In-group device \"" << deviceName
+                          << "\" has unexpected multi-index size "
+                          << device.size() << ". Expected " << meshAxes.size()
+                          << ".";
+  }
+
+  for (size_t i = 0; i < device.size(); ++i) {
+    if (!ShapedType::isDynamic(device[i]) &&
+        !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
+        meshShape[meshAxes[i]] <= device[i]) {
+      return emitError(loc)
+             << "Out of bounds coordinate " << i << " for in-group device \""
+             << deviceName << "\"."
+             << " Got " << device[i] << ", but expected value in the range [0, "
+             << (meshShape[meshAxes[i]] - 1) << "].";
+    }
+  }
+  return success();
+}
+
 static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
                                     SymbolTableCollection &symbolTable) {
   mesh::ClusterOp mesh =
@@ -338,7 +365,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
   return success();
 }
 
-static LogicalResult verifyAllGatherOperandAndResultShape(
+static LogicalResult verifyGatherOperandAndResultShape(
     Value operand, Value result, int64_t gatherAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   auto resultRank = result.getType().template cast<ShapedType>().getRank();
@@ -410,7 +437,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
   return success();
 }
 
-static LogicalResult verifyReduceScatterOperandAndResultShape(
+static LogicalResult verifyScatterOperandAndResultShape(
     Value operand, Value result, int64_t scatterAxis,
     ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
   ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -459,9 +486,9 @@ AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
   auto gatherAxis = getGatherAxis().getSExtValue();
-  return verifyAllGatherOperandAndResultShape(getOperand(), getResult(),
-                                              gatherAxis, getMeshAxes(),
-                                              mesh.value().canonicalDimSizes());
+  return verifyGatherOperandAndResultShape(getOperand(), getResult(),
+                                           gatherAxis, getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
 }
 
 void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -510,8 +537,22 @@ void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 
 LogicalResult
 BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                              MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -519,17 +560,48 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto gatherAxis = getGatherAxis().getSExtValue();
+  return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis,
+                                           getMeshAxes(),
+                                           mesh.value().canonicalDimSizes());
+}
+
+void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
-// mesh.receive op
+// mesh.recv op
 //===----------------------------------------------------------------------===//
 
 LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (getSource() && failed(verifyInGroupDevice(
+                         getLoc(), getSourceAttrName(), getSource().value(),
+                         getSourceDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -537,8 +609,22 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  return success();
+}
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                           MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -552,7 +638,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
     return failure();
   }
 
-  return verifyReduceScatterOperandAndResultShape(
+  return verifyScatterOperandAndResultShape(
       getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
       mesh.value().canonicalDimSizes());
 }
@@ -567,8 +653,25 @@ void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 //===----------------------------------------------------------------------===//
 
 LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(),
+                                 getRootDynamic(), getMeshAxes(), meshShape))) {
+    return failure();
+  }
+
+  auto scatterAxis = getScatterAxis().getSExtValue();
+  return verifyScatterOperandAndResultShape(getInput(), getResult(),
+                                            scatterAxis, getMeshAxes(),
+                                            mesh.value().canonicalDimSizes());
+}
+
+void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,8 +679,22 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  auto meshShape = mesh.value().canonicalDimSizes();
+  if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(),
+                                 getDestination(), getDestinationDynamic(),
+                                 getMeshAxes(), meshShape))) {
+    return failure();
+  }
+  return success();
+}
+
+void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *context) {
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -585,8 +702,25 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  // TODO
-  return failure();
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+
+  auto meshAxes = getMeshAxes();
+  auto shiftAxis = getShiftAxis().getZExtValue();
+  if (llvm::find(meshAxes, shiftAxis) == meshAxes.end()) {
+    return emitError() << "Invalid shift axis " << shiftAxis
+                       << ". It must be one of the grouping mesh axes.";
+  }
+
+  return success();
+}
+
+void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                          MLIRContext *context) {
+  // TODO: remove op when offset is 0 or if it is a rotate with an
+  // offset % shift_axis_mesh_dim_size == 0.
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index baee9faa645c93..0a00ab41268d01 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,58 @@ func.func @all_gather_empty_mesh_axes(
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: func @broadcast_empty_mesh_axes
+func.func @broadcast_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.broadcast
+  %0 = mesh.broadcast %arg0 on @mesh0
+    mesh_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @gather_empty_mesh_axes
+func.func @gather_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.gather
+  %0 = mesh.gather %arg0 on @mesh0
+    mesh_axes = []
+    gather_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @receive_empty_mesh_axes
+func.func @receive_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.recv
+  %0 = mesh.recv %arg0 on @mesh0
+    mesh_axes = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @reduce_empty_mesh_axes
+func.func @reduce_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.reduce
+  %0 = mesh.reduce %arg0 on @mesh0
+    mesh_axes = []
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
 // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
 func.func @reduce_scatter_empty_mesh_axes(
 // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
@@ -99,3 +151,30 @@ func.func @reduce_scatter_default_reduction(
     : tensor<4xf32> -> tensor<2xf64>
   return %0 : tensor<2xf64>
 }
+
+// CHECK-LABEL: func @scatter_empty_mesh_axes
+func.func @scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.scatter
+  %0 = mesh.scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    root = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: func @send_empty_mesh_axes
+func.func @send_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.send
+  %0 = mesh.send %arg0 on @mesh0
+    mesh_axes = []
+    destination = []
+    : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index a26e3950186e95..03994f8f011e1f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -298,6 +298,221 @@ func.func @all_to_all_invalid_non_dynamic_result_split_dimension_size(
 
 // -----
 
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_dimension_out_of_bounds(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{Out of bounds coordinate 0 for in-group device "root". Got 3, but expected value in the range [0, 2].}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [3]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_root_wrong_number_dimensions(
+    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
+  // expected-error@+1 {{In-group device "root" has unexpected multi-index size 2. Expected 1.}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [2, 2]
+    : (tensor<2xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 3x?)
+
+func.func @broadcast_different_input_and_result_type(
+    %arg0 : tensor<2xi8>) -> tensor<2xi16> {
+  // expected-error@+1 {{'mesh.broadcast' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0]
+    root = [2]
+    : (tensor<2xi8>) -> tensor<2xi16>
+  return %0 : tensor<2xi16>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_wrong_return_element_type(
+    %arg0 : tensor<1xf32>) -> tensor<1xi8> {
+  // expected-error@+1 {{'mesh.gather' op failed to verify that all of {input, result} have same element type}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+    : (tensor<1xf32>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_non_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 4, but got 5.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 0 root = [0]
+    : (tensor<3x4xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 1x2)
+
+func.func @gather_invalid_gather_axis_dimension_size(
+    %arg0 : tensor<3x4xf32>) -> tensor<3x5xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 1. Expected 8, but got 5.}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [1] gather_axis = 1 root = [0]
+    : (tensor<3x4xf32>) -> tensor<3x5xf32>
+  return %0 : tensor<3x5xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 3.}}
+  %0 = mesh.gather %arg0 on @mesh0 gather_axis = 0 root = []
+    : (tensor<?xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis 1 is out of bounds [0, 1).}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = 1 root = [0]
+    : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 1, dim_sizes = 1)
+
+func.func @gather_invalid_negative_gather_axis(
+    %arg0 : tensor<3xf32>) -> tensor<3xf32> {
+  // expected-error@+1 {{Gather axis -1 is out of bounds [0, 1).}}
+  %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0] gather_axis = -1 root = [0]
+    : (tensor<3xf32>) -> tensor<3xf32>
+  return %0 : tensor<3xf32>
+}
+
+// -----
+
+mesh....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/74905


More information about the Mlir-commits mailing list