[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

Frank Schlimbach llvmlistbot at llvm.org
Wed Jul 10 01:01:18 PDT 2024


================
@@ -33,6 +35,59 @@ using MeshAxesAttr = DenseI16ArrayAttr;
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+class MeshSharding {
+private:
+  ::mlir::FlatSymbolRefAttr mesh;
+  SmallVector<MeshAxesAttr> split_axes;
+  SmallVector<MeshAxis> partial_axes;
+  ReductionKind partial_type;
+  SmallVector<int64_t> static_halo_sizes;
+  SmallVector<int64_t> static_sharded_dims_sizes;
+  SmallVector<Value> dynamic_halo_sizes;
+  SmallVector<Value> dynamic_sharded_dims_sizes;
+
+public:
+  MeshSharding() = default;
+  MeshSharding(Value rhs);
+  static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_,
+                          ArrayRef<MeshAxesAttr> split_axes_,
+                          ArrayRef<MeshAxis> partial_axes_ = {},
+                          ReductionKind partial_type_ = ReductionKind::Sum,
+                          ArrayRef<int64_t> static_halo_sizes_ = {},
+                          ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
+                          ArrayRef<Value> dynamic_halo_sizes_ = {},
+                          ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
+  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
+  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
+  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
+  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
+  ReductionKind getPartialType() const { return partial_type; }
+  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
+  ArrayRef<int64_t> getStaticShardedDimsSizes() const {
+    return static_sharded_dims_sizes;
+  }
+  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
+  ArrayRef<Value> getDynamicShardedDimsSizes() const {
+    return dynamic_sharded_dims_sizes;
+  }
+  operator bool() const { return (!mesh) == false; }
+  bool operator==(Value rhs) const;
+  bool operator!=(Value rhs) const;
+  bool operator==(const MeshSharding &rhs) const;
+  bool operator!=(const MeshSharding &rhs) const;
+  bool sameExceptConstraint(const MeshSharding &rhs) const;
+  bool sameConstraint(const MeshSharding &rhs) const;
----------------
fschlimb wrote:

Agreed. I couldn't come up with a better name. If you have a suggestion, I'd be happy to change it.

https://github.com/llvm/llvm-project/pull/98145


More information about the Mlir-commits mailing list