[Mlir-commits] [mlir] mlir::mesh::shardingOp adding shard-size control (PR #98145)

Boian Petkantchin llvmlistbot at llvm.org
Tue Jul 9 18:55:36 PDT 2024


================
@@ -398,10 +478,160 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
                       std::mem_fn(&MeshAxesAttr::empty));
 }
 
-bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
+/// Returns true if `rhs` has the same halo sizes and sharded-dimension sizes
+/// as this sharding, considering both the static and the dynamic lists.
+///
+/// Each list is compared only against the corresponding list of `rhs`: the
+/// lengths must agree and the elements must be pairwise equal.
+bool MeshSharding::sameConstraint(const MeshSharding &rhs) const {
+  // Length check followed by an element-wise comparison. Routing all four
+  // lists through one helper guarantees every list is paired with its own
+  // counterpart (previously two of the length checks were crossed between
+  // the static sharded-dims sizes and the dynamic halo sizes).
+  auto sameList = [](const auto &lhsList, const auto &rhsList) {
+    return lhsList.size() == rhsList.size() && llvm::equal(lhsList, rhsList);
+  };
+  return sameList(getStaticHaloSizes(), rhs.getStaticHaloSizes()) &&
+         sameList(getStaticShardedDimsSizes(),
+                  rhs.getStaticShardedDimsSizes()) &&
+         sameList(getDynamicHaloSizes(), rhs.getDynamicHaloSizes()) &&
+         sameList(getDynamicShardedDimsSizes(),
+                  rhs.getDynamicShardedDimsSizes());
+}
+
+/// Equality against the value `rhs`: true only when both
+/// sameExceptConstraint(rhs) and sameConstraint(rhs) hold. The constraint
+/// check is not evaluated if the first check already fails.
+bool MeshSharding::operator==(Value rhs) const {
+  if (!sameExceptConstraint(rhs))
+    return false;
+  return sameConstraint(rhs);
+}
+
+/// Inequality against the value `rhs`; exactly the negation of
+/// operator==(Value).
+bool MeshSharding::operator!=(Value rhs) const {
+  const bool isSame = (*this == rhs);
+  return !isSame;
+}
+
+/// Full equality of two shardings: true only when both
+/// sameExceptConstraint(rhs) and sameConstraint(rhs) hold. The constraint
+/// check is not evaluated if the first check already fails.
+bool MeshSharding::operator==(const MeshSharding &rhs) const {
+  if (!sameExceptConstraint(rhs))
+    return false;
+  return sameConstraint(rhs);
+}
+
+bool MeshSharding::operator!=(const MeshSharding &rhs) const {
   return !(*this == rhs); // Exact negation of operator==(const MeshSharding &).
 }
 
+/// Constructs a MeshSharding from a value that is defined by a ShardingOp.
+///
+/// The mesh, split axes, static/dynamic halo sizes and static/dynamic
+/// sharded-dimension sizes are read from the defining op. When the op has no
+/// partial axes an empty list is used, and when it has no partial type
+/// ReductionKind::Sum is used.
+///
+/// Asserts if `rhs` is not defined by a ShardingOp.
+MeshSharding::MeshSharding(Value rhs) {
+  // getDefiningOp<OpTy>() tolerates a null defining operation, so a value
+  // without a defining op reaches the assert below instead of failing inside
+  // the cast itself.
+  auto shardingOp = rhs.getDefiningOp<ShardingOp>();
+  assert(shardingOp && "expected sharding op");
+  *this = get(shardingOp.getMeshAttr(), shardingOp.getSplitAxes().getAxes(),
+              shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>()),
+              shardingOp.getPartialType().value_or(ReductionKind::Sum),
+              shardingOp.getStaticHaloSizes(),
+              shardingOp.getStaticShardedDimsSizes(),
+              SmallVector<Value>(shardingOp.getDynamicHaloSizes()),
+              SmallVector<Value>(shardingOp.getDynamicShardedDimsSizes()));
+}
+
+MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
+                               ArrayRef<MeshAxesAttr> split_axes_,
+                               ArrayRef<MeshAxis> partial_axes_,
+                               ReductionKind partial_type_,
+                               ArrayRef<int64_t> static_halo_sizes_,
+                               ArrayRef<int64_t> static_sharded_dims_sizes_,
+                               ArrayRef<Value> dynamic_halo_sizes_,
+                               ArrayRef<Value> dynamic_sharded_dims_sizes_) {
+  MeshSharding res;
+  res.mesh = mesh_;
+  res.split_axes.resize(split_axes_.size());
+  for (auto [i, axis] : llvm::enumerate(split_axes_)) {
+    res.split_axes[i] =
+        MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef());
+  }
+
+  auto do_copy = [&](auto src, auto &dst) {
+    dst.resize(src.size());
+    for (auto [i, v] : llvm::enumerate(src)) {
+      dst[i] = v;
+    }
+  };
----------------
sogartar wrote:

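There is already an [`llvm::copy`](https://github.com/llvm/llvm-project/blob/e5aa7999bf1bc9a6c7a03bbc65036325a4ec5503/llvm/include/llvm/ADT/STLExtras.h#L1824) helper in STLExtras.h that could replace the hand-written element-by-element loop in `do_copy`.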

https://github.com/llvm/llvm-project/pull/98145


More information about the Mlir-commits mailing list