[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 30 07:29:51 PDT 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff cea9dd833cf800aeb005286b2667483cc5a8d688 13c590c515a34ae699baac8c386d6749aace778e --extensions h,cpp -- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 6a25f18e6a..1444eec0f7 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -489,18 +489,22 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
tgtCoreOffs, coreShape, strides);
// finally update the halo
- auto updateHaloResult = builder.create<UpdateHaloOp>(
- sourceShard.getLoc(),
- RankedTensorType::get(outShape, sourceShard.getType().getElementType()),
- sourceShard, initOprnd, mesh.getSymName(),
- MeshAxesArrayAttr::get(builder.getContext(),
- sourceSharding.getSplitAxes()),
- sourceSharding.getDynamicHaloSizes(),
- sourceSharding.getStaticHaloSizes(),
- targetSharding.getDynamicHaloSizes(),
- targetSharding.getStaticHaloSizes()).getResult();
- return std::make_tuple(
- cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding);
+ auto updateHaloResult =
+ builder
+ .create<UpdateHaloOp>(
+ sourceShard.getLoc(),
+ RankedTensorType::get(outShape,
+ sourceShard.getType().getElementType()),
+ sourceShard, initOprnd, mesh.getSymName(),
+ MeshAxesArrayAttr::get(builder.getContext(),
+ sourceSharding.getSplitAxes()),
+ sourceSharding.getDynamicHaloSizes(),
+ sourceSharding.getStaticHaloSizes(),
+ targetSharding.getDynamicHaloSizes(),
+ targetSharding.getStaticHaloSizes())
+ .getResult();
+ return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
+ targetSharding);
}
return std::nullopt;
}
@@ -725,8 +729,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
} else {
// Insert resharding.
- TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
- spmdizationMap.lookup(srcShardOp));
+ TypedValue<ShapedType> srcSpmdValue =
+ cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/114238
More information about the Mlir-commits mailing list