[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 02:00:44 PST 2025


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff e7de6036983641ccf0fb45afd3eb96ff962525aa d105b89098381f28213859f1890f1994332eacca --extensions cpp,h -- mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp mlir/test/Dialect/Arith/mesh-spmdize.cpp mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h mlir/include/mlir/InitAllDialects.h mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 98165c658b..c789fc527e 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -775,7 +775,7 @@ MeshSharding::MeshSharding(Value rhs) {
   auto splitAxes = shardingOp.getSplitAxes().getAxes();
   auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
   // If splitAxes and partialAxes are empty, use "empty" constructor.
-  if(splitAxes.empty() && partialAxes.empty()) {
+  if (splitAxes.empty() && partialAxes.empty()) {
     *this = MeshSharding(shardingOp.getMeshAttr());
     return;
   }
@@ -796,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
                                ArrayRef<Value> dynamic_halo_sizes_,
                                ArrayRef<Value> dynamic_sharded_dims_offsets_) {
   MeshSharding res(mesh_);
-  if(split_axes_.empty() && partial_axes_.empty()) {
+  if (split_axes_.empty() && partial_axes_.empty()) {
     return res;
   }
 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index ec64728a87..f427d004c5 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -448,8 +448,8 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
                                            AffineMap map) {
   Value operandValue = opOperand.get();
   auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
-  if(!operandType) {
-    if(operandValue.getType().isIntOrIndexOrFloat())
+  if (!operandType) {
+    if (operandValue.getType().isIntOrIndexOrFloat())
       return MeshSharding();
     return failure();
   }
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 6bb5d4a66f..b2acbf20b3 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -23,9 +23,10 @@ using namespace mlir::mesh;
 namespace {
 
 // Sharding of tensor.empty/tensor.splat
-template<typename OpTy>
+template <typename OpTy>
 struct CreatorOpShardingInterface
-    : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
+    : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
+                                              OpTy> {
   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
     auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
     return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
     auto type = dyn_cast<RankedTensorType>(val.getType());
     if (!type)
       return {};
-    return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
+    return SmallVector<AffineMap>(
+        op->getNumOperands() + op->getNumResults(),
+        {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
   }
 
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
           newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
         }
       }
-      newOp =
-          builder.create<OpTy>(op->getLoc(), shardType, newOperands);
+      newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
       spmdizationMap.map(op->getResult(0), newOp->getResult(0));
     } else {
       // `clone` will populate the mapping of old to new results.
@@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
-    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
+    EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
+        *ctx);
+    SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
+        *ctx);
   });
 }
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
index 0688e14b1c..89e1cd0307 100644
--- a/mlir/test/Dialect/Arith/mesh-spmdize.cpp
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
@@ -4,14 +4,17 @@
 
 mesh.mesh @mesh4x4(shape = 4x4)
 
-// CHECK-LABEL: func @test_spmdize_constant
-// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
-// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
-// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
-    %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
-    %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
-    %ci = arith.constant 434 : i32
-    return %sharding_annotated_1 : tensor<1024x1024xf32>
+    // CHECK-LABEL: func @test_spmdize_constant
+    // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
+    // tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
+    // i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+    func.func @test_spmdize_constant()
+        ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
+  % cst =
+      arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> %
+      sharding_1 = mesh.sharding @mesh4x4 split_axes =
+          [[0]] : !mesh.sharding % sharding_annotated_1 =
+              mesh.shard % cst to % sharding_1 : tensor<1024x1024xf32> % ci =
+                  arith.constant 434 : i32 return % sharding_annotated_1
+      : tensor<1024x1024xf32>
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/124724


More information about the Mlir-commits mailing list