[Mlir-commits] [mlir] [mlir][shard, mpi] Lowering shard.allgather to MPI (PR #177202)
Tuomas Kärnä
llvmlistbot at llvm.org
Thu Jan 22 01:54:08 PST 2026
================
@@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
func.func @allreduce_new_type(
// CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
%arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
- // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
// CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
// CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
// CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
- // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
- // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+ // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64>
%0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
// CHECK: return [[valloc]] : memref<3x4xf64>
return %0 : memref<3x4xf64>
}
+
+ // CHECK-LABEL: func @allgather_tensor
+ func.func @allgather_tensor(
+ // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+ // CHECK-SAME: -> tensor<3x20xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
+ // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+ // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+ // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+ // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
+ // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
----------------
tkarna wrote:
Going forward, we should have a mechanism to deallocate the output buffer when it is safe to do so.
https://github.com/llvm/llvm-project/pull/177202