[Mlir-commits] [mlir] [mlir][mesh] resubmitting #144079 (PR #145897)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 26 07:04:07 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
#144079 introduced a test with an uninitialized access, which caused a buildbot failure (https://lab.llvm.org/buildbot/#/builders/164/builds/11140), and was therefore reverted in b0ef9125347cbaea031273feb72ac0d6bc74ddee.
This PR is an exact copy of #144079 plus a trivial fix (96c8525c82c5474d8522a973553e85a2033bee5b) that default-initializes `MeshSharding::partial_type`.
---
Full diff: https://github.com/llvm/llvm-project/pull/145897.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+1-4)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+12)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+15)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+15-12)
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+29-13)
- (added) mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir (+26)
- (added) mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir (+27)
- (added) mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir (+49)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 1dc178586e918..7213fde45c695 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -44,7 +44,7 @@ class MeshSharding {
::mlir::FlatSymbolRefAttr mesh;
SmallVector<MeshAxesAttr> split_axes;
SmallVector<MeshAxis> partial_axes;
- ReductionKind partial_type;
+ ReductionKind partial_type = ReductionKind::Sum;
SmallVector<int64_t> static_halo_sizes;
SmallVector<int64_t> static_sharded_dims_offsets;
SmallVector<Value> dynamic_halo_sizes;
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
// Use newShardOp if it is not null. Otherwise create a new one.
// May insert resharding if required.
// Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand, OpBuilder &builder,
- ShardOp &newShardOp);
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
OpBuilder &builder);
void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae..a2424d43a8ba9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
namespace mesh {
+/// This enum controls the order in which the `sharding-propagation` pass
+/// visits operations. It is selected through the pass's `traversal` option.
+enum class TraversalOrder {
+ /// Forward traversal: visit operations in block order only.
+ Forward,
+ /// Backward traversal: visit operations in reverse block order only.
+ Backward,
+ /// Forward traversal followed by a backward traversal.
+ ForwardBackward,
+ /// Backward traversal followed by a forward traversal (the pass default).
+ BackwardForward
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d64..11ec7e78cd5e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
operation, and the operations themselves are added with sharding option
attributes.
}];
+ let options = [
+ Option<"traversal", "traversal",
+ "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+ "Traversal order to use for sharding propagation:",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+ "Forward only traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+ "backward only traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+ "forward-backward traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+ "backward-forward traversal.")
+ )}]>,
+ ];
let dependentDialects = [
"mesh::MeshDialect"
];
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 0a01aaf776e7d..b8cc91da722f0 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,13 +298,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
return type;
}
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
- OpOperand &operand,
- OpBuilder &builder,
- ShardOp &newShardOp) {
+static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+ Value &operandValue,
+ Operation *operandOp,
+ OpBuilder &builder,
+ ShardOp &newShardOp) {
OpBuilder::InsertionGuard insertionGuard(builder);
- Value operandValue = operand.get();
- Operation *operandOp = operand.getOwner();
builder.setInsertionPointAfterValue(operandValue);
ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
if (shardOp && sharding == shardOp.getSharding() &&
@@ -323,9 +322,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
/*annotate_for_users*/ false);
}
- IRRewriter rewriter(builder);
- rewriter.replaceUsesWithIf(
- operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+ operandValue.replaceUsesWithIf(
+ newShardOp, [operandOp, operandValue](OpOperand &use) {
return use.getOwner() == operandOp && use.get() == operandValue;
});
@@ -336,15 +334,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
newShardOp.getSharding(),
/*annotate_for_users*/ true);
- rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+ newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
}
void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
OpResult result,
OpBuilder &builder) {
ShardOp newShardOp;
- for (auto &use : llvm::make_early_inc_range(result.getUses())) {
- maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+ SmallVector<std::pair<Value, Operation *>> uses;
+ for (auto &use : result.getUses()) {
+ uses.emplace_back(use.get(), use.getOwner());
+ }
+ for (auto &[operandValue, operandOp] : uses) {
+ maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
+ builder, newShardOp);
}
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9d..6751fafaf1776 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
//===----------------------------------------------------------------------===//
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+ using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
@@ -382,18 +385,31 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
- // 1. propagate in reversed order
- for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
-
- LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
- << funcOp << "\n");
- LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-
- // 2. propagate in original order
- for (Operation &op : llvm::make_early_inc_range(block))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
+    // Visits every op of `range` in the range's iteration order. On the first
+    // op that fails to propagate, it signals pass failure and returns
+    // failure() so that the caller skips the remaining propagation phases
+    // (the pre-patch code returned immediately as well).
+    // Early-inc iteration, as used by the loops this replaces, keeps the walk
+    // well-defined in case visitOp inserts or removes ops around the visited
+    // op.
+    auto traverse = [&](auto &&range, const char *order) -> LogicalResult {
+      for (Operation &op : llvm::make_early_inc_range(range)) {
+        if (failed(visitOp(&op, builder))) {
+          signalPassFailure();
+          return failure();
+        }
+      }
+      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
+                        << funcOp << "\n");
+      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      return success();
+    };
+
+    // 1. Propagate in reversed order.
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward) {
+      if (failed(traverse(llvm::reverse(block), "backward")))
+        return;
+    }
+
+    // 2. Propagate in original order.
+    if (traversal != TraversalOrder::Backward) {
+      if (failed(traverse(block, "forward")))
+        return;
+    }
+
+    // 3. Propagate in backward order if needed.
+    if (traversal == TraversalOrder::ForwardBackward) {
+      if (failed(traverse(llvm::reverse(block), "backward")))
+        return;
+    }
}
};
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..4223d01d65111
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_backward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    // CHECK-COUNT-2: mesh.shard
+    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
+    // CHECK: tensor.empty()
+    // Backward-only traversal must not annotate anything after this point;
+    // "mesh.shard" also matches "mesh.sharding".
+    // CHECK-NOT: mesh.shard
+    %2 = tensor.empty() : tensor<6x6xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
+      : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    // CHECK: return
+    return %3 : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..dd2eee2f7def8
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward_backward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
+    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %2 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
+      : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
+    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+    // CHECK: return
+    return %annotated_col : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..98e9931b8de94
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+ mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+ %c1_i32 = arith.constant 1 : i32
+ // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+ %0 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[v1:%.*]] = linalg.fill ins
+ // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+ %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+ %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+ // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+ %3 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+ // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+ // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+ %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+ : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+ ^bb0(%in: i32, %in_2: i32, %out: i32):
+ %9 = arith.addi %in, %in_2 : i32
+ linalg.yield %9 : i32
+ } -> tensor<6x6xi32>
+ %c0_i32 = arith.constant 0 : i32
+ %6 = tensor.empty() : tensor<i32>
+ %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+ // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+ // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+ %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %9 = arith.addi %in, %init : i32
+ linalg.yield %9 : i32
+ }
+ // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+ %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+ %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+ return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+ }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/145897
More information about the Mlir-commits mailing list