[mlir] [llvm] [mlir][mesh] Add spmdization pass (PR #80518)

Chengji Yao via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 4 19:47:11 PST 2024


================
@@ -78,6 +92,35 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
----------------
yaochengji wrote:

I'm not sure whether `unshard` is a suitable name. Should we call it `gatherDimension` instead?
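For reference, the two helpers can be sketched as standalone C++. This is a minimal sketch under some assumptions. The body of `unshardDimension` is cut off in the hunk above, so this version assumes it multiplies the per-shard size by `shardCount`, which is the same result that gathering all shards along that axis would give. `kDynamic` and `isDynamic` are hypothetical stand-ins for `ShapedType::kDynamic` and `ShapedType::isDynamic`. Because the divisibility is asserted, plain division gives the same result as `ceilDiv` here.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for mlir::ShapedType::kDynamic (sentinel for "unknown size").
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for mlir::ShapedType::isDynamic.
inline bool isDynamic(int64_t v) { return v == kDynamic; }

// Per-shard size when `dim` is split evenly across `shardCount` devices.
inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
  if (isDynamic(dim) || isDynamic(shardCount))
    return kDynamic;  // Unknown sizes stay unknown.
  assert(dim % shardCount == 0 && "dimension must split evenly");
  return dim / shardCount;  // Equal to ceilDiv because the division is exact.
}

// Assumed inverse: full size after gathering `shardCount` equal shards.
// (The review suggests `gatherDimension` as an alternative name.)
inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
  if (isDynamic(dim) || isDynamic(shardCount))
    return kDynamic;  // Unknown sizes stay unknown.
  return dim * shardCount;
}
```

With these definitions, a dimension of 8 split over 4 devices gives shards of size 2, and `unshardDimension(2, 4)` recovers 8. Either name describes the same round trip, so the choice is mostly about whether the helper should read as the inverse of `shardDimension` or as the shape effect of an all-gather.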

https://github.com/llvm/llvm-project/pull/80518

