[Mlir-commits] [mlir] [mlir][shard, mpi] Lowering shard.allgather to MPI (PR #177202)

Tuomas Kärnä llvmlistbot at llvm.org
Thu Jan 22 01:54:08 PST 2026


================
@@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_new_type(
     // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
     %arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
-    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
-    // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+    // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
     // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
     // CHECK: return [[valloc]] : memref<3x4xf64>
     return %0 : memref<3x4xf64>
   }
+
+  // CHECK-LABEL: func @allgather_tensor
+  func.func @allgather_tensor(
+      // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+      // CHECK-SAME: -> tensor<3x20xf32>
+      %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
+    // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
+    // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
+    // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm
+    // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32>
----------------
tkarna wrote:

Going forward, we should have a mechanism to deallocate the output buffer when it is safe to do so.

https://github.com/llvm/llvm-project/pull/177202
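For reference, the shapes in the `allgather_tensor` test (`tensor<3x4xf32>` in, `tensor<3x20xf32>` out) imply that the gather grows dimension 1 by a factor of 5, i.e. five processes along the gathered grid axes. The following is a minimal Python sketch of that logical semantics only. It assumes the local blocks are concatenated in rank order along dimension 1, and `allgather_dim1` is a hypothetical helper that is not part of MLIR or MPI. It does not describe how the PR lays out the MPI receive buffer.

```python
from typing import List

Matrix = List[List[float]]


def allgather_dim1(blocks: List[Matrix]) -> Matrix:
    """Hypothetical helper: concatenate per-rank blocks along dimension 1 (columns).

    Assumes every block has the same number of rows and that blocks are
    ordered by rank, so rank i's columns land at offset i * cols.
    """
    rows = len(blocks[0])
    assert all(len(b) == rows for b in blocks), "blocks must share dimension 0"
    # For each row index, splice together that row from every rank's block.
    return [[x for b in blocks for x in b[r]] for r in range(rows)]


# Five ranks, each holding a 3x4 block filled with its own rank id.
num_ranks = 5
local = [[[float(rank)] * 4 for _ in range(3)] for rank in range(num_ranks)]
gathered = allgather_dim1(local)
print(len(gathered), len(gathered[0]))  # 3 20
```

Each row of the result holds four columns from rank 0, then four from rank 1, and so on, matching the 3x20 result type in the test.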

