[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 28 02:00:44 PST 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff e7de6036983641ccf0fb45afd3eb96ff962525aa d105b89098381f28213859f1890f1994332eacca --extensions cpp,h -- mlir/include/mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp mlir/test/Dialect/Arith/mesh-spmdize.cpp mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h mlir/include/mlir/InitAllDialects.h mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 98165c658b..c789fc527e 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -775,7 +775,7 @@ MeshSharding::MeshSharding(Value rhs) {
auto splitAxes = shardingOp.getSplitAxes().getAxes();
auto partialAxes = shardingOp.getPartialAxes().value_or(ArrayRef<MeshAxis>());
// If splitAxes and partialAxes are empty, use "empty" constructor.
- if(splitAxes.empty() && partialAxes.empty()) {
+ if (splitAxes.empty() && partialAxes.empty()) {
*this = MeshSharding(shardingOp.getMeshAttr());
return;
}
@@ -796,7 +796,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
ArrayRef<Value> dynamic_halo_sizes_,
ArrayRef<Value> dynamic_sharded_dims_offsets_) {
MeshSharding res(mesh_);
- if(split_axes_.empty() && partial_axes_.empty()) {
+ if (split_axes_.empty() && partial_axes_.empty()) {
return res;
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index ec64728a87..f427d004c5 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -448,8 +448,8 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand,
AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = dyn_cast<RankedTensorType>(operandValue.getType());
- if(!operandType) {
- if(operandValue.getType().isIntOrIndexOrFloat())
+ if (!operandType) {
+ if (operandValue.getType().isIntOrIndexOrFloat())
return MeshSharding();
return failure();
}
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 6bb5d4a66f..b2acbf20b3 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -23,9 +23,10 @@ using namespace mlir::mesh;
namespace {
// Sharding of tensor.empty/tensor.splat
-template<typename OpTy>
+template <typename OpTy>
struct CreatorOpShardingInterface
- : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>, OpTy> {
+ : public ShardingInterface::ExternalModel<CreatorOpShardingInterface<OpTy>,
+ OpTy> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto ndims = mlir::cast<ShapedType>(op->getResult(0).getType()).getRank();
return SmallVector<utils::IteratorType>(ndims,
@@ -38,7 +39,9 @@ struct CreatorOpShardingInterface
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
- return SmallVector<AffineMap>(op->getNumOperands() + op->getNumResults(), {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
+ return SmallVector<AffineMap>(
+ op->getNumOperands() + op->getNumResults(),
+ {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)});
}
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
@@ -82,8 +85,7 @@ struct CreatorOpShardingInterface
newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]);
}
}
- newOp =
- builder.create<OpTy>(op->getLoc(), shardType, newOperands);
+ newOp = builder.create<OpTy>(op->getLoc(), shardType, newOperands);
spmdizationMap.map(op->getResult(0), newOp->getResult(0));
} else {
// `clone` will populate the mapping of old to new results.
@@ -100,7 +102,9 @@ void mlir::tensor::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(*ctx);
- SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(*ctx);
+ EmptyOp::template attachInterface<CreatorOpShardingInterface<EmptyOp>>(
+ *ctx);
+ SplatOp::template attachInterface<CreatorOpShardingInterface<SplatOp>>(
+ *ctx);
});
}
diff --git a/mlir/test/Dialect/Arith/mesh-spmdize.cpp b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
index 0688e14b1c..89e1cd0307 100644
--- a/mlir/test/Dialect/Arith/mesh-spmdize.cpp
+++ b/mlir/test/Dialect/Arith/mesh-spmdize.cpp
@@ -4,14 +4,17 @@
mesh.mesh @mesh4x4(shape = 4x4)
-// CHECK-LABEL: func @test_spmdize_constant
-// CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x1024xf32>
-// CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 : i32
-// CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
-func.func @test_spmdize_constant() -> (tensor<1024x1024xf32>) attributes {llvm.emit_c_interface} {
- %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32>
- %sharding_1 = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding
- %sharding_annotated_1 = mesh.shard %cst to %sharding_1 : tensor<1024x1024xf32>
- %ci = arith.constant 434 : i32
- return %sharding_annotated_1 : tensor<1024x1024xf32>
+ // CHECK-LABEL: func @test_spmdize_constant
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<0.000000e+00> :
+ // tensor<256x1024xf32> CHECK-NEXT: [[vc434_i32:%.*]] = arith.constant 434 :
+ // i32 CHECK-NEXT: return [[vcst]] : tensor<256x1024xf32>
+ func.func @test_spmdize_constant()
+ ->(tensor<1024x1024xf32>)attributes{llvm.emit_c_interface} {
+ % cst =
+ arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> %
+ sharding_1 = mesh.sharding @mesh4x4 split_axes =
+ [[0]] : !mesh.sharding % sharding_annotated_1 =
+ mesh.shard % cst to % sharding_1 : tensor<1024x1024xf32> % ci =
+ arith.constant 434 : i32 return % sharding_annotated_1
+ : tensor<1024x1024xf32>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/124724
More information about the Mlir-commits mailing list