[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)
Frank Schlimbach
llvmlistbot at llvm.org
Wed Jul 10 01:46:48 PDT 2024
================
@@ -398,10 +478,160 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}
-bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
+/// Returns true when this sharding and `rhs` carry identical halo-size and
+/// sharded-dims-size lists, both static and dynamic. Only these four lists
+/// are compared; no other sharding field is inspected here.
+bool MeshSharding::sameConstraint(const MeshSharding &rhs) const {
+  // Each list is compared only against the *same* list of `rhs`. The size
+  // check short-circuits before the element-wise comparison.
+  if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() ||
+      !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes()))
+    return false;
+  if (rhs.getStaticShardedDimsSizes().size() !=
+          getStaticShardedDimsSizes().size() ||
+      !llvm::equal(getStaticShardedDimsSizes(),
+                   rhs.getStaticShardedDimsSizes()))
+    return false;
+  if (rhs.getDynamicHaloSizes().size() != getDynamicHaloSizes().size() ||
+      !llvm::equal(getDynamicHaloSizes(), rhs.getDynamicHaloSizes()))
+    return false;
+  if (rhs.getDynamicShardedDimsSizes().size() !=
+          getDynamicShardedDimsSizes().size() ||
+      !llvm::equal(getDynamicShardedDimsSizes(),
+                   rhs.getDynamicShardedDimsSizes()))
+    return false;
+  return true;
+}
+
+/// Compares this sharding with the sharding described by the `ShardingOp`
+/// that defines `rhs`.
+bool MeshSharding::operator==(Value rhs) const {
+  // Build the `MeshSharding` for `rhs` once and reuse the `MeshSharding`
+  // overload. Passing `rhs` straight to `sameConstraint` (which takes a
+  // `const MeshSharding &`) would construct a fresh temporary per call.
+  return *this == MeshSharding(rhs);
+}
+
+/// Inequality against the sharding defined by `rhs`: the negation of
+/// `operator==(Value)`.
+bool MeshSharding::operator!=(Value rhs) const {
+  const bool isSame = (*this == rhs);
+  return !isSame;
+}
+
+/// Two shardings are equal when they agree on everything checked by
+/// `sameExceptConstraint` and on the halo / sharded-dims size lists checked
+/// by `sameConstraint`. `sameConstraint` is only evaluated if the first
+/// check passes.
+bool MeshSharding::operator==(const MeshSharding &rhs) const {
+  if (!sameExceptConstraint(rhs))
+    return false;
+  return sameConstraint(rhs);
+}
+
+/// Inequality between two shardings: the negation of
+/// `operator==(const MeshSharding &)`.
+bool MeshSharding::operator!=(const MeshSharding &rhs) const {
return !(*this == rhs);
}
+/// Builds a `MeshSharding` from `rhs`, which must be produced by a
+/// `ShardingOp`. The following are copied from that op:
+///   - the mesh and split axes;
+///   - the partial axes (an empty list if absent);
+///   - the partial type (`ReductionKind::Sum` if absent);
+///   - the static and dynamic halo / sharded-dims sizes.
+MeshSharding::MeshSharding(Value rhs) {
+  // `getDefiningOp<ShardingOp>()` yields a null op when `rhs` has no
+  // defining op or is defined by a different op. The assert then reports
+  // the problem, instead of `dyn_cast` being handed a null pointer.
+  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
+  assert(shardingOp && "expected sharding op");
+  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
+              shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+              shardingOp.getPartialType().value_or(ReductionKind::Sum),
+              shardingOp.getStaticHaloSizes(),
+              shardingOp.getStaticShardedDimsSizes(),
+              // Copied into SmallVectors so they can be passed to `get` as
+              // `ArrayRef<Value>`.
+              SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+}
+
+MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
+ ArrayRef<MeshAxesAttr> split_axes_,
+ ArrayRef<MeshAxis> partial_axes_,
+ ReductionKind partial_type_,
+ ArrayRef<int64_t> static_halo_sizes_,
+ ArrayRef<int64_t> static_sharded_dims_sizes_,
+ ArrayRef<Value> dynamic_halo_sizes_,
+ ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+ MeshSharding res;
+ res.mesh = mesh_;
+ res.split_axes.resize(split_axes_.size());
+ for (auto [i, axis] : llvm::enumerate(split_axes_)) {
+ res.split_axes[i] =
+ MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+ }
+
+ auto do_copy = [&](auto src, auto &dst) {
+ dst.resize(src.size());
+ for (auto [i, v] : llvm::enumerate(src)) {
+ dst[i] = v;
+ }
+ };
----------------
fschlimb wrote:
done.
https://github.com/llvm/llvm-project/pull/98145
More information about the Mlir-commits
mailing list