[Mlir-commits] [mlir] [mlir][mesh]fixes for 0d tensors (PR #132948)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 25 08:52:22 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: clang-format, the C/C++ code formatter, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff feecb201ab041dbcf8266960aba4b252a789bcd4 f54a37cb31285aa94fae0c60e0c7560841d1708d --extensions cpp,h -- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 8aaa070411..7b3107a6e6 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -716,8 +716,8 @@ void mesh::spmdizeTriviallyShardableOperation(
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
- newResult.setType(
- shardType(newResult.getType(),
- getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
+ newResult.setType(shardType(
+ newResult.getType(),
+ getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
}
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b6cb06ae31..69a80b1a05 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -689,33 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
static std::vector<MeshSharding> getResultShardings(Operation &op) {
std::vector<MeshSharding> res;
res.reserve(op.getNumResults());
- llvm::transform(op.getResults(), std::back_inserter(res),
- [&op](OpResult result) {
- if (!result.hasOneUse() || result.use_empty()) {
- return MeshSharding();
- }
- TypedValue<RankedTensorType> rankedTensor =
- dyn_cast<TypedValue<RankedTensorType>>(result);
- if (!rankedTensor) {
- return MeshSharding();
- }
- Operation *userOp = *result.getUsers().begin();
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
- if (shardOp) {
- return MeshSharding(shardOp.getSharding());
- }
- if (rankedTensor.getType().getRank() == 0) {
- // This is a 0d tensor result without explicit sharding.
- // Find mesh symbol from operands, if any.
- // Shardings without mesh are not always fully supported yet.
- for (auto operand: op.getOperands()) {
- if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
- return MeshSharding(sharding.getMeshAttr());
- }
- }
- }
- return MeshSharding();
- });
+ llvm::transform(
+ op.getResults(), std::back_inserter(res), [&op](OpResult result) {
+ if (!result.hasOneUse() || result.use_empty()) {
+ return MeshSharding();
+ }
+ TypedValue<RankedTensorType> rankedTensor =
+ dyn_cast<TypedValue<RankedTensorType>>(result);
+ if (!rankedTensor) {
+ return MeshSharding();
+ }
+ Operation *userOp = *result.getUsers().begin();
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+ if (shardOp) {
+ return MeshSharding(shardOp.getSharding());
+ }
+ if (rankedTensor.getType().getRank() == 0) {
+ // This is a 0d tensor result without explicit sharding.
+ // Find mesh symbol from operands, if any.
+ // Shardings without mesh are not always fully supported yet.
+ for (auto operand : op.getOperands()) {
+ if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+ return MeshSharding(sharding.getMeshAttr());
+ }
+ }
+ }
+ return MeshSharding();
+ });
return res;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/132948
More information about the Mlir-commits mailing list is available from the list's information page.