[llvm] [mlir] [mlir][mesh] Add spmdization pass (PR #80518)

Chengji Yao via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 4 19:47:18 PST 2024


================
@@ -78,6 +92,35 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
   return res;
 }
 
+inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  assert(dim % shardCount == 0);
+  return ceilDiv(dim, shardCount);
+}
+
+inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+  if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+    return ShapedType::kDynamic;
+
+  return dim * shardCount;
+}
+
+// Return the sharded shape `shape` acording ot sharding `sharding`.
----------------
yaochengji wrote:

```suggestion
// Return the sharded shape `shape` according to sharding `sharding`.
```
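The two helpers in the hunk above are inverses on static sizes, and both propagate a dynamic size unchanged. Below is a minimal standalone sketch of that arithmetic. It assumes a local `kDynamic` sentinel that mirrors `mlir::ShapedType::kDynamic`, so it builds without MLIR. The `isDynamic` helper is a hypothetical stand-in for `ShapedType::isDynamic`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for mlir::ShapedType::kDynamic; MLIR uses
// std::numeric_limits<int64_t>::min() as its dynamic-size sentinel.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for ShapedType::isDynamic.
inline bool isDynamic(int64_t v) { return v == kDynamic; }

// Size of one shard when `dim` is split evenly across `shardCount` devices.
// A dynamic input yields a dynamic result.
inline int64_t shardDimension(int64_t dim, int64_t shardCount) {
  if (isDynamic(dim) || isDynamic(shardCount))
    return kDynamic;
  assert(dim % shardCount == 0 && "dimension must divide evenly");
  // Under the divisibility assertion, plain division equals ceilDiv.
  return dim / shardCount;
}

// Inverse of shardDimension: the full size recovered from one shard's size.
inline int64_t unshardDimension(int64_t dim, int64_t shardCount) {
  if (isDynamic(dim) || isDynamic(shardCount))
    return kDynamic;
  return dim * shardCount;
}
```

For example, `shardDimension(12, 4)` gives 3 and `unshardDimension(3, 4)` gives 12. A `kDynamic` operand in either position passes through unchanged.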

https://github.com/llvm/llvm-project/pull/80518