[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 30 07:29:51 PDT 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


⚠️ The C/C++ code formatter, clang-format, found issues in your code. ⚠️

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff cea9dd833cf800aeb005286b2667483cc5a8d688 13c590c515a34ae699baac8c386d6749aace778e --extensions h,cpp -- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6a25f18e6a..1444eec0f7 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -489,18 +489,22 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
         tgtCoreOffs, coreShape, strides);
 
     // finally update the halo
-    auto updateHaloResult = builder.create<UpdateHaloOp>(
-        sourceShard.getLoc(),
-        RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
-        sourceShard, initOprnd, mesh.getSymName(),
-        MeshAxesArrayAttr::get(builder.getContext(),
-                               sourceSharding.getSplitAxes()),
-        sourceSharding.getDynamicHaloSizes(),
-        sourceSharding.getStaticHaloSizes(),
-        targetSharding.getDynamicHaloSizes(),
-        targetSharding.getStaticHaloSizes()).getResult();
-    return std::make_tuple(
-        cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
+    auto updateHaloResult =
+        builder
+            .create<UpdateHaloOp>(
+                sourceShard.getLoc(),
+                RankedTensorType::get(outShape,
+                                      sourceShard.getType().getElementType()),
+                sourceShard, initOprnd, mesh.getSymName(),
+                MeshAxesArrayAttr::get(builder.getContext(),
+                                       sourceSharding.getSplitAxes()),
+                sourceSharding.getDynamicHaloSizes(),
+                sourceSharding.getStaticHaloSizes(),
+                targetSharding.getDynamicHaloSizes(),
+                targetSharding.getStaticHaloSizes())
+            .getResult();
+    return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+                           targetSharding);
   }
   return std::nullopt;
 }
@@ -725,8 +729,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
     targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
   } else {
     // Insert resharding.
-    TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
-        spmdizationMap.lookup(srcShardOp));
+    TypedValue<ShapedType> srcSpmdValue =
+        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
     targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                               symbolTableCollection);
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/114238


More information about the Mlir-commits mailing list