[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)

Matteo Franciolini llvmlistbot at llvm.org
Thu Oct 31 14:29:07 PDT 2024


================
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
           typename OutShape>
 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
                        const SplitAxes &splitAxes, OutShape &outShape,
-                       ArrayRef<int64_t> shardedDimsSizes = {},
+                       ArrayRef<int64_t> shardedDimsOffsets = {},
                        ArrayRef<int64_t> haloSizes = {}) {
   std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
             llvm::adl_begin(outShape));
 
-  if (!shardedDimsSizes.empty()) {
+  if (!shardedDimsOffsets.empty()) {
+    uint64_t pos = 0;
     for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
-      if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
-        for (auto dimSz : shardedDimsSizes) {
-          auto inAxis = dimSz % inShape.size();
-          assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
-                 inShape[inAxis] == ShapedType::kDynamic);
-        }
-#endif // NDEBUG
-      } else {
-        // find sharded dims in sharded_dims_sizes with same static size on
-        // all devices. Use kDynamic for dimensions with dynamic or non-uniform
-        // sizes in sharded_dims_sizes.
-        auto sz = shardedDimsSizes[tensorAxis];
-        bool same = true;
-        for (size_t i = tensorAxis + inShape.size();
-             i < shardedDimsSizes.size(); i += inShape.size()) {
-          if (shardedDimsSizes[i] != sz) {
-            same = false;
-            break;
+      if (!innerSplitAxes.empty()) {
+        auto sz = shardedDimsOffsets[pos];
+        bool same = !ShapedType::isDynamicShape(meshShape);
+        if (same) {
+          // find sharded dims in shardedDimsOffsets with same static size on
----------------
mfrancio wrote:

```suggestion
          // Find sharded dims in shardedDimsOffsets with same static size on
```

https://github.com/llvm/llvm-project/pull/114238


More information about the Mlir-commits mailing list