[Mlir-commits] [mlir] [mlir][mesh] Add collective communication operations (PR #71960)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 10 09:18:51 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.
I have not included ops like reduce, gather, send and recv to see first if reviewers notice any systemic issues. Also this PR is already big enough.
---
Patch is 43.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71960.diff
8 Files Affected:
- (added) mlir/docs/Dialects/Mesh.md (+34)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+5-3)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+2)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+216)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+417)
- (added) mlir/test/Dialect/Mesh/canonicalization.mlir (+72)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+240)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+119)
``````````diff
diff --git a/mlir/docs/Dialects/Mesh.md b/mlir/docs/Dialects/Mesh.md
new file mode 100644
index 000000000000000..6dd4f79022061ee
--- /dev/null
+++ b/mlir/docs/Dialects/Mesh.md
@@ -0,0 +1,34 @@
+# 'mesh' Dialect
+
+The `mesh` dialect contains a set of attributes, operations and interfaces that
+are useful for representing sharding and communication on a device mesh
+cluster.
+
+[TOC]
+
+## Collective Communication Operations
+There are a number of operations in the Mesh dialect to facilitate
+communication between devices in a mesh.
+It is assumed that the user is familiar with collective operations.
+[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
+explanation.
+The main addition is that the collectives in this dialect have mesh
+semantics.
+The operation attributes `mesh` and `mesh_axes` specify a set of device mesh
+axes that partition the devices into disjoint groups.
+The collective operation is performed between devices in the same group.
+Devices that have the same coordinates outside of axes `mesh_axes` are in the
+same group.
+For example, if we have a device mesh of size `2x3x4x5` and the partition mesh
+axes set is `{0, 1}` then devices are partitioned into the groups
+`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
+Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
+Device (1, 0, 2, 4) will be in another group.
+
+## Operations
+
+[include "Dialects/MeshOps.md"]
+
+## Attributes
+
+[include "Dialects/MeshAttributes.md"]
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a91ef569347bff1..9d39b1b3329fb4b 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
let cppNamespace = "::mlir::mesh";
let description = [{
- The `mesh` dialect contains a set of attributes, operations, interfaces that
- are useful for representing sharding and communication on device mesh
- cluster.
+ See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
}];
let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+// Attribute wrapper around the Mesh_Partial enum so that a partial/reduction
+// kind can be attached to ops; with the assembly format below it prints as
+// e.g. `#mesh.partial<max>`.
+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 05eba66a89949b6..7698d60813a8f10 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -10,9 +10,11 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index a8aa0a694bee29f..15354babe870599 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -77,6 +79,15 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
attr-dict
}];
+  let extraClassDeclaration = [{
+    // Returns the mesh shape with one entry per mesh axis: the explicitly
+    // specified `dim_sizes`, padded with 0 for every trailing axis whose size
+    // was not given.
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    // Writes the same `getRank()` values as above to `outIt`.
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      auto dimSizes = getDimSizes();
+      // std::copy returns the advanced output iterator. It must be reused for
+      // the padding; otherwise plain (non-inserter) iterators would make
+      // fill_n overwrite the dimension sizes that were just copied.
+      outIt = std::copy(dimSizes.begin(), dimSizes.end(), outIt);
+      std::fill_n(outIt, getRank() - dimSizes.size(), 0);
+    }
+  }];
let hasVerifier = 1;
}
@@ -171,4 +182,209 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+// Common base class for the mesh collective communication ops.
+// It appends SymbolUserOpInterface to the op's traits (the ops reference a
+// `mesh.cluster` symbol), gives all collectives the same
+// `$input attr-dict : type($input) -> type($result)` assembly format, and
+// provides `extraClassDeclarationBase`, which declares the
+// `verifySymbolUses` hook; derived ops assign it to their
+// `extraClassDeclaration`.
+class Mesh_CollectiveCommunicationOpBase<
+ string mnemonic, list<Trait> traits = []> :
+ Mesh_Op<mnemonic,
+ !listconcat(traits,
+ [SymbolUserOpInterface])> {
+ let assemblyFormat = "$input attr-dict `:` type($input) `->` type($result)";
+ code extraClassDeclarationBase = [{
+ ::mlir::LogicalResult verifySymbolUses(
+ ::mlir::SymbolTableCollection &symbolTable);
+ }];
+}
+
+// mesh.all_gather: concatenates the inputs of each device group along the
+// `gather_axis` tensor axis. `mesh_axes` defaults to the empty set ("{}")
+// when omitted. The op has a custom C++ verifier (hasVerifier) and
+// canonicalization patterns (hasCanonicalizer).
+def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-gather over a device mesh.";
+ let description = [{
+ Gathers along the `gather_axis` tensor axis.
+ The order of input tensors in the resulting tensor is the same as the
+ order of the corresponding devices' multi-index in the mesh.
+
+ Example:
+ ```mlir
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.all_gather %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, gather_axis = 1 : index
+ } : tensor<2x2xi8> -> tensor<2x4xi8>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$gather_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+// mesh.all_reduce: reduces the inputs of each device group elementwise.
+// Operand and result shapes must match (SameOperandsAndResultShape), but the
+// element types may differ. `reduction` defaults to
+// ::mlir::mesh::Partial::Sum and `mesh_axes` to the empty set when omitted.
+// Only canonicalization patterns are declared; there is no custom verifier.
+def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
+ SameOperandsAndResultShape]> {
+ let summary = "All-reduce over a device mesh.";
+ let description = [{
+ The accumulation element type is specified by the result type and
+ it does not need to match the input element type.
+ The input element is converted to the result element type before
+ performing the reduction.
+
+ Attributes:
+ `reduction`: Indicates the reduction method.
+
+ Example:
+ ```
+ %1 = mesh.all_reduce %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1, 0>, reduction = #mesh.partial<max>
+ } : tensor<3x4xf32> -> tensor<3x4xf64>
+ ```
+ }];
+ let arguments = (ins
+ AnyRankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+// mesh.all_to_all: within each device group, every device splits its input
+// along `split_axis`, sends the i-th piece to the i-th device of the group,
+// and concatenates the received pieces along `concat_axis`.
+// The example below uses 3x2 tensors so that it matches the diagram, where
+// each of the three devices holds a 3x2 block.
+def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank]> {
+ let summary = "All-to-all over a device mesh.";
+ let description = [{
+ Performs an all-to-all on tensor pieces split along `split_axis`.
+ The resulting pieces are concatenated along `concat_axis` on each device.
+ Example:
+ ```
+ mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
+ ...
+ %1 = mesh.all_to_all %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 0>, split_axis = 0, concat_axis = 0
+ } : tensor<3x2xi8> -> tensor<3x2xi8>
+ ```
+ Input:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
+ | 11 12 | 21 22 | 31 32 |
+ | 13 14 | 23 24 | 33 34 |
+ | 15 16 | 25 26 | 35 36 |
+ +-------+-------+-------+
+ ```
+ Result:
+ ```
+ device device device
+ (0) (1) (2)
+ +-------+-------+-------+
+ | 11 12 | 13 14 | 15 16 |
+ | 21 22 | 23 24 | 25 26 |
+ | 31 32 | 33 34 | 35 36 |
+ +-------+-------+-------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ APIntAttr:$split_axis,
+ APIntAttr:$concat_axis
+ );
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
+// mesh.reduce_scatter: reduces the inputs of each device group with
+// `reduction` (default ::mlir::mesh::Partial::Sum), then splits the reduced
+// tensor along `scatter_axis` and hands one piece to each device of the group.
+// The example is kept consistent with its diagram:
+// - a 2x2 mesh declared with rank 2;
+// - 2x2 blocks per device;
+// - summed values.
+def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
+ SameOperandsAndResultRank]> {
+ let summary = "Reduce-scatter over a device mesh.";
+ let description = [{
+ Reduces the input across each device group and then scatters the result
+ within the group.
+ The tensor is split along `scatter_axis` and the pieces are distributed
+ across the device group.
+ Example:
+ ```
+ mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
+ ...
+ %1 = mesh.reduce_scatter %0 {
+ mesh = @mesh0, mesh_axes = array<i16: 1>, reduction = #mesh.partial<sum>, scatter_axis = 0
+ } : tensor<2x2xf32> -> tensor<1x2xf64>
+ ```
+ Input:
+ ```
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ Result:
+ ```
+ +-------+
+ | 6 8 | <- device (0, 0)
+ +-------+
+ | 10 12 | <- device (0, 1)
+ +-------+
+ | 22 24 | <- device (1, 0)
+ +-------+
+ | 26 28 | <- device (1, 1)
+ +-------+
+ ```
+ }];
+ let arguments = (ins
+ AnyNon0RankedTensor:$input,
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedOptionalAttr<DenseI16ArrayAttr, "{}">:$mesh_axes,
+ DefaultValuedOptionalAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
+ APIntAttr:$scatter_axis
+ );
+ let results = (outs
+ AnyRankedTensor:$result
+ );
+ let hasCanonicalizer = 1;
+ let hasVerifier = 1;
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 588704f24574f90..6efc4c4ecc326ad 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -8,10 +8,26 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
+#include <functional>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <string>
+#include <utility>
#define DEBUG_TYPE "mesh-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
@@ -21,6 +37,60 @@ using namespace mlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
+namespace {
+
+/// Sorts [begin, end) and drops adjacent duplicates in the manner of
+/// std::unique. Returns the new logical end of the range; elements past it
+/// are left in a valid but unspecified state.
+template <typename It>
+It canonicalizeSetAsArray(It begin, It end) {
+  std::sort(begin, end);
+  It newEnd = std::unique(begin, end);
+  return newEnd;
+}
+
+/// Range overload of `canonicalizeSetAsArray`.
+template <typename R>
+auto canonicalizeSetAsArray(R &&range) {
+  auto first = adl_begin(range);
+  auto last = adl_end(range);
+  return canonicalizeSetAsArray(first, last);
+}
+
+/// Turns `vec` in place into a sorted vector without duplicates and returns
+/// it.
+template <typename T>
+SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
+  auto uniqueEnd = canonicalizeSetAsArray(vec);
+  vec.erase(uniqueEnd, vec.end());
+  return vec;
+}
+
+/// A mesh dimension size that is not strictly positive is treated as
+/// dynamic.
+template <typename DimSize>
+bool isMeshDimensionDynamic(DimSize size) {
+  return !(DimSize(0) < size);
+}
+
+/// Index of a device mesh axis.
+using MeshAxis = int16_t;
+
+/// Tensor dimension size that may be ShapedType::kDynamic. The arithmetic
+/// operators below yield a dynamic size if either operand is dynamic.
+struct DimensionSize {
+  static DimensionSize dynamic() { return DimensionSize(ShapedType::kDynamic); }
+  DimensionSize(int64_t val) : size(val) {}
+  int64_t value() const { return size; }
+  operator int64_t() const { return size; }
+  bool isDynamic() const { return ShapedType::isDynamic(size); }
+
+private:
+  int64_t size;
+};
+
+DimensionSize operator/(DimensionSize lhs, DimensionSize rhs) {
+  return (lhs.isDynamic() || rhs.isDynamic())
+             ? DimensionSize::dynamic()
+             : DimensionSize(lhs.value() / rhs.value());
+}
+
+DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
+  return (lhs.isDynamic() || rhs.isDynamic())
+             ? DimensionSize::dynamic()
+             : DimensionSize(lhs.value() * rhs.value());
+}
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
@@ -96,6 +166,12 @@ LogicalResult ClusterOp::verify() {
return success();
}
+/// Returns the mesh dimension sizes: the explicitly specified `dim_sizes`
+/// followed by a 0 for each remaining mesh axis up to `rank`.
+SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
+  SmallVector<int64_t> dimSizes;
+  auto inserter = std::back_inserter(dimSizes);
+  canonicalDimSizes(inserter);
+  return dimSizes;
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
@@ -129,6 +205,347 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// collective communication ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Returns `attr` with its axes sorted and deduplicated. Returns std::nullopt
+/// when `attr` is null or holds no axes.
+std::optional<DenseI16ArrayAttr>
+canonicalizeAxesSetAttribute(DenseI16ArrayAttr attr) {
+  if (!attr)
+    return std::nullopt;
+  ArrayRef<int16_t> rawAxes = attr.asArrayRef();
+  SmallVector<int16_t> uniqueAxes(rawAxes.begin(), rawAxes.end());
+  if (canonicalizeSetAsVector(uniqueAxes).empty())
+    return std::nullopt;
+  return DenseI16ArrayAttr::get(attr.getContext(), uniqueAxes);
+}
+
+/// Canonicalizes an axis-set attribute (named by `axisSetAttr`) on `Op`: the
+/// axes are sorted and deduplicated, and an empty set is dropped entirely.
+///
+/// The pattern reports success only when it actually changes the op, so the
+/// greedy rewrite driver can reach a fixed point. All in-place modifications
+/// go through the rewriter so that listeners are notified.
+template <typename Op>
+struct AxesSetCanonicalizationPattern : OpRewritePattern<Op> {
+  AxesSetCanonicalizationPattern(MLIRContext *context, StringRef axisSetAttr)
+      : OpRewritePattern<Op>(context), axisSetAttr(axisSetAttr) {}
+  LogicalResult matchAndRewrite(Op op,
+                                PatternRewriter &rewriter) const override {
+    auto currentAttr =
+        op->template getAttrOfType<DenseI16ArrayAttr>(axisSetAttr);
+    // An absent attribute is already in canonical form.
+    if (!currentAttr)
+      return failure();
+
+    std::optional<DenseI16ArrayAttr> canonicalAttr =
+        canonicalizeAxesSetAttribute(currentAttr);
+    if (!canonicalAttr) {
+      // An empty axis set is equivalent to omitting the attribute.
+      rewriter.updateRootInPlace(op, [&]() { op->removeAttr(axisSetAttr); });
+      return success();
+    }
+
+    // Attributes are uniqued, so equality means it is already canonical.
+    if (*canonicalAttr == currentAttr)
+      return failure();
+
+    rewriter.updateRootInPlace(
+        op, [&]() { op->setAttr(axisSetAttr, *canonicalAttr); });
+    return success();
+  }
+
+  std::string axisSetAttr;
+};
+
+/// Adds to `patterns` the pattern that canonicalizes the `mesh_axes`
+/// attribute of `Op`.
+template <typename Op>
+void populateMeshAxesSetCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                 MLIRContext *context) {
+  StringRef meshAxesAttrName = "mesh_axes";
+  patterns.add<AxesSetCanonicalizationPattern<Op>>(context, meshAxesAttrName);
+}
+
+/// Verifies the mesh references of a collective op:
+/// - the `mesh` symbol attribute is present;
+/// - it resolves to a `mesh.cluster` op;
+/// - every entry of `mesh_axes` (if present) is a valid 0-based axis of that
+///   mesh.
+///
+/// The lookup goes through the caller-provided `symbolTable`, so symbol tables
+/// cached by the caller are reused instead of being rebuilt for every op.
+template <typename Op>
+LogicalResult verifyMeshSymbolUses(Op op, SymbolTableCollection &symbolTable) {
+  FlatSymbolRefAttr symbolAttr = op.getMeshAttr();
+  if (!symbolAttr) {
+    return op.emitError() << "Unspecified \"mesh\" symbol attribute.";
+  }
+  mesh::ClusterOp clusterOp =
+      symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op.getOperation(),
+                                                           symbolAttr);
+  if (!clusterOp) {
+    return op.emitError() << "Undefined required mesh symbol \""
+                          << symbolAttr.getValue() << "\".";
+  }
+  DenseI16ArrayAttr meshAxes = op.getMeshAxesAttr();
+  if (!meshAxes) {
+    return success();
+  }
+  // Keep the rank in 64 bits. Narrowing it to MeshAxis (int16_t) would wrap
+  // for large ranks and defeat the bounds check below.
+  int64_t rank = clusterOp.getRank();
+  for (MeshAxis axis : meshAxes.asArrayRef()) {
+    if (axis < 0 || axis >= rank) {
+      return op.emitError()
+             << "0-based mesh axis index " << axis
+             << " is out of bounds. The referenced mesh \""
+             << symbolAttr.getValue() << "\" is of rank " << rank << ".";
+    }
+  }
+
+  return success();
+}
+
+/// Multiplies together the elements of [begin, end). An empty range yields 1.
+/// The accumulator has the decayed element type of the range.
+template <typename It>
+auto product(It begin, It end) {
+  using ElementType = std::decay_t<decltype(*begin)>;
+  std::multiplies<ElementType> multiply;
+  ElementType result(1);
+  for (It it = begin; it != end; ++it)
+    result = multiply(result, *it);
+  return result;
+}
+
+/// Range overload of `product`.
+template <typename R>
+auto product(R &&range) {
+  return product(adl_begin(range), adl_end(range));
+}
+
+/// Returns the number of devices in one collective group: the product of the
+/// sizes of the mesh axes listed in `meshAxes`. Returns ShapedType::kDynamic
+/// if any of those axes has a dynamic size. Each axis counts at most once,
+/// and axes not present in `meshShape` are ignored.
+int64_t collectiveDeviceGroupSize(ArrayRef<MeshAxis> meshAxes,
+                                  ArrayRef<int64_t> meshShape) {
+  int64_t groupSize = 1;
+  const MeshAxis meshRank = MeshAxis(meshShape.size());
+  for (MeshAxis axis = 0; axis < meshRank; ++axis) {
+    if (!llvm::is_contained(meshAxes, axis))
+      continue;
+    const int64_t axisSize = meshShape[axis];
+    if (isMeshDimensionDynamic(axisSize))
+      return ShapedType::kDynamic;
+    groupSize *= axisSize;
+  }
+  return groupSize;
+}
+
+/// Checks one result dimension against its expected size. A dynamic result
+/// size is always accepted. Otherwise a mismatch emits an error at `loc` that
+/// names `resultAxis`.
+LogicalResult verifyDimensionCompatibility(Location loc,
+                                           int64_t expectedDimSize,
+                                           int64_t resultDimSize,
+                                           int64_t resultAxis) {
+  const bool resultIsDynamic = ShapedType::isDynamic(resultDimSize);
+  if (resultIsDynamic || resultDimSize == expectedDimSize)
+    return success();
+
+  return emitError(loc) << "Dimension size mismatch for result axis "
+                        << resultAxis << ". Expected "
+                        << (ShapedType::isDynamic(expectedDimSize)
+                                ? Twine("dynamic")
+                                : Twine(expectedDimSize))
+                        << ", but got " << resultDimSize << ".";
+}
+
+/// Verifies the result shape of a gather-like collective:
+/// - along `gatherAxis`, the result size must equal the operand size times
+///   the device group size selected by `meshAxes` within `meshShape`;
+/// - every other axis must match the operand.
+/// Dynamic sizes propagate through the DimensionSize multiplication.
+LogicalResult verifyGatherOperandAndResultShape(Value operand, Value result,
+                                                int64_t gatherAxis,
+                                                ArrayRef<MeshAxis> meshAxes,
+                                                ArrayRef<int64_t> meshShape) {
+  auto operandType = operand.getType().cast<ShapedType>();
+  auto resultType = result.getType().cast<ShapedType>();
+  const DimensionSize groupSize(collectiveDeviceGroupSize(meshAxes, meshShape));
+  const int64_t rank = operandType.getRank();
+  for (int64_t axis = 0; axis < rank; ++axis) {
+    DimensionSize operandDimSize(operandType.getDimSize(axis));
+    DimensionSize expected = operandDimSize;
+    if (axis == gatherAxis)
+      expected = groupSize * operandDimSize;
+    if (failed(verifyDimensionCompatibility(result.getLoc(), expected,
+                                            resultType.getDimSize(axis),
+                                            axis)))
+      return failure();
+  }
+  return success();
+}
+
+/// Resolves the `mesh.cluster` referenced by `op`. Fails, after emitting a
+/// diagnostic, if the reference or its `mesh_axes` are invalid.
+template <typename Op>
+FailureOr<ClusterOp> getMesh(Op op) {
+  SymbolTableCollection symbols;
+  // The op verifier runs before the SymbolUserOpInterface hook, so the mesh
+  // symbol has to be validated here before it can be resolved.
+  if (failed(verifyMeshSymbolUses(op, symbols)))
+    return failure();
+  return symbols.lookupNearestSymbolFrom<mesh::ClusterOp>(op.getOperation(),
+                                                         op.getMeshAttr());
+}
+
+template <typename Op>
+LogicalResult verifyGather(Op op) {
+ auto rank = op.getResult().getType().template cast<ShapedType>().getRank();
+ auto gatherAxis = op.getGatherAxis().getSExtValue();
+ if (gatherAxis < 0 || gatherAxis >= rank) {
+ return op.emitError() << "Gather axis " << gatherAxis
+ << " is out of bounds [0, " << rank << ").";
+ }
+
+ auto mesh = getMesh(op);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyGatherOperandAndResultShape(op.getOperand(), op.get...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/71960
More information about the Mlir-commits
mailing list