[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)
Boian Petkantchin
llvmlistbot at llvm.org
Tue Jul 9 18:55:36 PDT 2024
================
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
+namespace mlir {
+namespace mesh {
+
+class MeshSharding {
+private:
+ ::mlir::FlatSymbolRefAttr mesh;
+ SmallVector<MeshAxesAttr> split_axes;
+ SmallVector<MeshAxis> partial_axes;
+ ReductionKind partial_type;
+ SmallVector<int64_t> static_halo_sizes;
+ SmallVector<int64_t> static_sharded_dims_sizes;
+ SmallVector<Value> dynamic_halo_sizes;
+ SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+ MeshSharding() = default;
+ MeshSharding(Value rhs);
+ static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+ ArrayRef<MeshAxesAttr> split_axes_,
+ ArrayRef<MeshAxis> partial_axes_ = {},
+ ReductionKind partial_type_ = ReductionKind::Sum,
+ ArrayRef<int64_t> static_halo_sizes_ = {},
+ ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+ ArrayRef<Value> dynamic_halo_sizes_ = {},
+ ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+ ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+ ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+ ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+ ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+ ReductionKind getPartialType() const { return partial_type; }
+ ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+ ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+ return static_sharded_dims_sizes;
+ }
+ ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+ ArrayRef<Value> getDynamicShardedDimsSizes() const {
+ return dynamic_sharded_dims_sizes;
+ }
+ operator bool() const { return (!mesh) == false; }
+ bool operator==(Value rhs) const;
+ bool operator!=(Value rhs) const;
+ bool operator==(const MeshSharding &rhs) const;
+ bool operator!=(const MeshSharding &rhs) const;
+ bool sameExceptConstraint(const MeshSharding &rhs) const;
+ bool sameConstraint(const MeshSharding &rhs) const;
----------------
sogartar wrote:
"Constraint" does not seem like the right name here.
It seems there is a concept hierarchy we have to define, e.g. replicated sharding, split sharding, equal-sized split sharding, (equal-sized) split sharding with halos, etc.
I don't know what the correct concepts are; they will probably become evident once we write many sharding transformations for various operations.
https://github.com/llvm/llvm-project/pull/98145
More information about the Mlir-commits
mailing list