[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jul 9 18:55:36 PDT 2024


================
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+class MeshSharding {
+private:
+  ::mlir::FlatSymbolRefAttr mesh;
+  SmallVector<MeshAxesAttr> split_axes;
+  SmallVector<MeshAxis> partial_axes;
+  ReductionKind partial_type;
+  SmallVector<int64_t> static_halo_sizes;
+  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<Value> dynamic_halo_sizes;
+  SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+  MeshSharding() = default;
+  MeshSharding(Value rhs);
+  static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+                          ArrayRef<MeshAxesAttr> split_axes_,
+                          ArrayRef<MeshAxis> partial_axes_ = {},
+                          ReductionKind partial_type_ = ReductionKind::Sum,
+                          ArrayRef<int64_t> static_halo_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<Value> dynamic_halo_sizes_ = {},
+                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+  ReductionKind getPartialType() const { return partial_type; }
+  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+    return static_sharded_dims_sizes;
+  }
+  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+  ArrayRef<Value> getDynamicShardedDimsSizes() const {
+    return dynamic_sharded_dims_sizes;
+  }
+  operator bool() const { return (!mesh) == false; }
+  bool operator==(Value rhs) const;
+  bool operator!=(Value rhs) const;
+  bool operator==(const MeshSharding &rhs) const;
+  bool operator!=(const MeshSharding &rhs) const;
+  bool sameExceptConstraint(const MeshSharding &rhs) const;
+  bool sameConstraint(const MeshSharding &rhs) const;
----------------
sogartar wrote:

`Constraint` does not seem like the right name here.
It seems there is a concept hierarchy we need to define, e.g. replicated sharding, split sharding, equisized split sharding, (equisized) split sharding with halo, etc.
I don't know yet what the correct concepts are; they will probably become evident once we write many sharding transformations for various operations.


https://github.com/llvm/llvm-project/pull/98145


More information about the Mlir-commits mailing list