[Mlir-commits] [mlir] [mlir][mesh]fixes for 0d tensors (PR #132948)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 25 08:52:22 PDT 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff feecb201ab041dbcf8266960aba4b252a789bcd4 f54a37cb31285aa94fae0c60e0c7560841d1708d --extensions cpp,h -- mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 8aaa070411..7b3107a6e6 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -716,8 +716,8 @@ void mesh::spmdizeTriviallyShardableOperation(
   // Set the result types to the sharded counterparts.
   for (auto [oldResult, newResult, sharding] :
        llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
-    newResult.setType(
-        shardType(newResult.getType(),
-                  getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
+    newResult.setType(shardType(
+        newResult.getType(),
+        getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
   }
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b6cb06ae31..69a80b1a05 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -689,33 +689,33 @@ static std::vector<MeshSharding> getOperandShardings(Operation &op) {
 static std::vector<MeshSharding> getResultShardings(Operation &op) {
   std::vector<MeshSharding> res;
   res.reserve(op.getNumResults());
-  llvm::transform(op.getResults(), std::back_inserter(res),
-                  [&op](OpResult result) {
-                    if (!result.hasOneUse() || result.use_empty()) {
-                      return MeshSharding();
-                    }
-                    TypedValue<RankedTensorType> rankedTensor =
-                        dyn_cast<TypedValue<RankedTensorType>>(result);
-                    if (!rankedTensor) {
-                      return MeshSharding();
-                    }
-                    Operation *userOp = *result.getUsers().begin();
-                    ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
-                    if (shardOp) {
-                      return MeshSharding(shardOp.getSharding());
-                    }
-                    if (rankedTensor.getType().getRank() == 0) {
-                      // This is a 0d tensor result without explicit sharding.
-                      // Find mesh symbol from operands, if any.
-                      // Shardings without mesh are not always fully supported yet.
-                      for (auto operand: op.getOperands()) {
-                        if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
-                          return MeshSharding(sharding.getMeshAttr());
-                        }
-                      }
-                    }
-                    return MeshSharding();
-                  });
+  llvm::transform(
+      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
+        if (!result.hasOneUse() || result.use_empty()) {
+          return MeshSharding();
+        }
+        TypedValue<RankedTensorType> rankedTensor =
+            dyn_cast<TypedValue<RankedTensorType>>(result);
+        if (!rankedTensor) {
+          return MeshSharding();
+        }
+        Operation *userOp = *result.getUsers().begin();
+        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
+        if (shardOp) {
+          return MeshSharding(shardOp.getSharding());
+        }
+        if (rankedTensor.getType().getRank() == 0) {
+          // This is a 0d tensor result without explicit sharding.
+          // Find mesh symbol from operands, if any.
+          // Shardings without mesh are not always fully supported yet.
+          for (auto operand : op.getOperands()) {
+            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
+              return MeshSharding(sharding.getMeshAttr());
+            }
+          }
+        }
+        return MeshSharding();
+      });
   return res;
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/132948


More information about the Mlir-commits mailing list