[Mlir-commits] [mlir] [mlir][shard, mpi] Lowering shard.allgather to MPI (PR #177202)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 21 08:58:44 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

- lowering `shard.all_gather` to `mpi.allgather`
- fixing the lowering of `shard.all_reduce`
- minor refactoring

---

Patch is 21.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/177202.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Shard/IR/ShardOps.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h (+6-6) 
- (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+115-70) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Shard/Transforms/Transforms.cpp (+9-8) 
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+43-12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 5e68f75ee08bf..6ef7c72d305ee 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -530,11 +530,11 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyNon0RankedTensor:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     IndexAttr:$gather_axis
   ));
   let results = (outs
-    AnyNon0RankedTensor:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
index 57d65e687ea35..1ddd1985389bc 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h
@@ -39,14 +39,14 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
                                  ImplicitLocOpBuilder &builder);
 
 // Get process linear index along the given grid axes.
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
-                                               ArrayRef<GridAxis> gridAxes,
-                                               ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+                         ArrayRef<GridAxis> gridAxes = {});
 // Get process linear index from a multi-index along the given grid axes .
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
-                         ArrayRef<GridAxis> gridAxes,
-                         ImplicitLocOpBuilder &builder);
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+                         ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes = {});
 
 } // namespace shard
 } // namespace mlir
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index b0831dc05abb1..1865914de9d84 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
@@ -507,103 +508,147 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
   }
 }
 
-struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename CommOp>
+struct CommOpPattern : public OpConversionPattern<CommOp> {
+  using OpConversionPattern<CommOp>::OpConversionPattern;
 
-  LogicalResult
-  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    SymbolTableCollection symbolTableCollection;
-    auto grid = adaptor.getGrid();
-    mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection);
-    if (!gridOp)
-      return op->emitError() << "No grid found for AllReduceOp";
-    if (ShapedType::isDynamicShape(gridOp.getShape()))
-      return op->emitError()
-             << "Dynamic grid shape not supported in AllReduceOp";
-
-    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
-    Value input = adaptor.getInput();
-    auto inputShape = cast<ShapedType>(input.getType()).getShape();
+  MemRefType getMemrefType(ShapedType tensorType) const {
+    return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  }
 
+  Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const {
+    auto itype = input.getType();
-    // If the source is a memref, cast it to a tensor.
-    if (isa<RankedTensorType>(input.getType())) {
-      auto memrefType = MemRefType::get(
-          inputShape, cast<ShapedType>(input.getType()).getElementType());
+    // If the source is a tensor, convert it to a memref.
+    if (isa<RankedTensorType>(itype)) {
+      auto memrefType = getMemrefType(cast<ShapedType>(itype));
       input = bufferization::ToBufferOp::create(iBuilder, memrefType, input);
+    } else {
+      assert(isa<MemRefType>(itype) &&
+             "expected input to be of MemRefType or TensorType");
     }
-    MemRefType inType = cast<MemRefType>(input.getType());
-
-    // Get the actual shape to allocate the buffer.
-    SmallVector<OpFoldResult> shape(inType.getRank());
-    for (auto i = 0; i < inType.getRank(); ++i) {
-      auto s = inputShape[i];
-      if (ShapedType::isDynamic(s))
-        shape[i] = memref::DimOp::create(iBuilder, input, s).getResult();
-      else
-        shape[i] = iBuilder.getIndexAttr(s);
-    }
+    return input;
+  }
 
-    // Allocate buffer and copy input to buffer.
-    Value buffer = memref::AllocOp::create(
-        iBuilder, shape, cast<ShapedType>(op.getType()).getElementType());
-    linalg::CopyOp::create(iBuilder, input, buffer);
+  FailureOr<GridOp> checkGrid(CommOp op,
+                              SymbolTableCollection &symbolTableCollection,
+                              bool allowDynamic = false) const {
+    GridOp gridOp = getGrid(op, symbolTableCollection);
+    if (!gridOp)
+      return op->emitError() << "Missing grid symbol.";
+    if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape()))
+      return op->emitError() << "Dynamic grid shape not supported.";
+    return gridOp;
+  }
 
-    // Get an MPI_Comm_split for the AllReduce operation.
+  Value getComm(GridOp &gridOp, ::llvm::ArrayRef<int16_t> gridAxes,
+                ImplicitLocOpBuilder &iBuilder) const {
+    // Get an MPI_Comm_split for a given grid and axes.
     // The color is the linear index of the process in the grid along the
-    // non-reduced axes. The key is the linear index of the process in the grid
-    // along the reduced axes.
-    SmallVector<Type> indexResultTypes(gridOp.getShape().size(),
-                                       iBuilder.getIndexType());
-    SmallVector<Value> myMultiIndex =
-        ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid)
-            .getResult();
-    Value zero = arith::ConstantIndexOp::create(iBuilder, 0);
-    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
-
-    auto redAxes = adaptor.getGridAxes();
-    for (auto axis : redAxes) {
-      multiKey[axis] = myMultiIndex[axis];
-      myMultiIndex[axis] = zero;
+    // non-'grid-axes'. The key is the linear index of the process in the grid
+    // along the grid-axes.
+    SmallVector<GridAxis> otherAxes;
+    for (GridAxis i = 0; i < static_cast<GridAxis>(gridOp.getShape().size());
+         ++i) {
+      if (!llvm::is_contained(gridAxes, i))
+        otherAxes.emplace_back(i);
     }

+    SmallVector<Type> indexResultTypes(otherAxes.size(),
+                                       iBuilder.getIndexType());
+    // NOTE(review): indexResultTypes is unused and can be removed.
     Value color =
-        createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder);
+        createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes);
     color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color);
-    Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder);
+    // NOTE(review): check color when otherAxes is empty (may mean all axes).
+    Value key =
+        createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes);
     key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key);

     // Finally split the communicator
-    auto commType = mpi::CommType::get(op->getContext());
+    auto commType = mpi::CommType::get(gridOp->getContext());
     Value commWorld = mpi::CommWorldOp::create(iBuilder, commType);
-    auto comm =
-        mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
-            .getNewcomm();
-
-    Value buffer1d = buffer;
-    // Collapse shape to 1d if needed
-    if (inType.getRank() > 1) {
-      ReassociationIndices reassociation(inType.getRank());
-      std::iota(reassociation.begin(), reassociation.end(), 0);
-      buffer1d = memref::CollapseShapeOp::create(
-          iBuilder, buffer, ArrayRef<ReassociationIndices>(reassociation));
-    }
+    return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key)
+        .getNewcomm();
+  }
+};
 
+struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
+  using CommOpPattern::CommOpPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = getAsMemref(adaptor.getInput(), iBuilder);
+    MemRefType inType = cast<MemRefType>(input.getType());
+    if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+    // NOTE(review): input is only copied; its layout check may be too strict.
+    // Allocate buffer and copy input to buffer.
+    Value buffer = memref::AllocOp::create(iBuilder, outType);
+    linalg::CopyOp::create(iBuilder, input, buffer);
+    // Get the right communicator
+    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
     // Create the MPI AllReduce operation.
-    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d,
+    mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer,
                              getMPIReductionOp(adaptor.getReductionAttr()),
                              comm);

-    // If the destination is a memref, cast it to a tensor
+    // If the destination is a tensor, convert the buffer back to a tensor
     if (isa<RankedTensorType>(op.getType()))
       buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer,
                                                  true);
-
     rewriter.replaceOp(op, buffer);
     return success();
   }
 };
 
+struct ConvertAllGatherOp : public CommOpPattern<AllGatherOp> {
+  using CommOpPattern::CommOpPattern;
+
+  LogicalResult
+  matchAndRewrite(AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = getAsMemref(adaptor.getInput(), iBuilder);
+    MemRefType inType = cast<MemRefType>(input.getType());
+    if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError(
+          "Expected static shaped memref in contiguous row-major layout.");
+    // NOTE(review): gather_axis is never read; MPI allgather stacks rank
+    // buffers along dim 0, so gather_axis != 0 likely needs a transpose.
+    Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder);
+    // Allocate output buffer
+    Value output = memref::AllocOp::create(iBuilder, outType);
+    // Create the MPI AllGather operation.
+    mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm);
+
+    // If the destination is a tensor, convert the buffer back to a tensor
+    if (isa<RankedTensorType>(op.getType()))
+      output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output,
+                                                 true);
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+
 struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -895,8 +940,8 @@ struct ConvertShardToMPIPass
 
     patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                  ConvertGetShardingOp, ConvertShardingOp, ConvertShardShapeOp,
-                 ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter,
-                                                                  ctxt);
+                 ConvertAllGatherOp, ConvertAllReduceOp,
+                 ConvertProcessLinearIndexOp>(typeConverter, ctxt);
     SymbolTableCollection stc;
     populateProcessMultiIndexOpLoweringPatterns(patterns, stc);
     populateAllSliceOpLoweringPatterns(patterns, stc);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 0ae2a9cc0318c..d0165595f9fb6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -128,7 +128,7 @@ static Value createDestinationPassingStyleInitOperand(
     ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp,
     ImplicitLocOpBuilder &builder) {
   Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex(
-      gridOp.getSymName(), reductionGridAxes, builder);
+      builder, gridOp.getSymName(), reductionGridAxes);
   Value zero = arith::ConstantIndexOp::create(builder, 0);
   Value isLeadProcess = arith::CmpIOp::create(
       builder, builder.getI1Type(), arith::CmpIPredicate::eq,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
index b433b8b0be7b2..835bc443d4b2a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp
@@ -208,9 +208,9 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes,
 }
 
 TypedValue<IndexType>
-createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
-                         ArrayRef<GridAxis> gridAxes,
-                         ImplicitLocOpBuilder &builder) {
+createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid,
+                         ValueRange processInGroupMultiIndex,
+                         ArrayRef<GridAxis> gridAxes) {
   Operation::result_range processGroupShape =
       GridShapeOp::create(builder, grid, gridAxes).getResult();
   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
@@ -224,11 +224,12 @@ createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex,
   return cast<TypedValue<IndexType>>(res);
 }
 
-TypedValue<IndexType> createProcessLinearIndex(StringRef grid,
-                                               ArrayRef<GridAxis> gridAxes,
-                                               ImplicitLocOpBuilder &builder) {
+TypedValue<IndexType> createProcessLinearIndex(ImplicitLocOpBuilder &builder,
+                                               StringRef grid,
+                                               ArrayRef<GridAxis> gridAxes) {
   return createProcessLinearIndex(
-      grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
-      gridAxes, builder);
+      builder, grid,
+      ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(),
+      gridAxes);
 }
 } // namespace mlir::shard
diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
index a0b6bfaf6fd3d..d4741102e4d3f 100644
--- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
+++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir
@@ -102,15 +102,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_tensor(
     // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
     %arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> {
-    // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32>
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
     // CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
     // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
     // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32>
     // CHECK: return [[v2]] : tensor<3x4xf32>
@@ -121,14 +120,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_memref(
     // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
     %arg0 : memref<3x4xf32>) -> memref<3x4xf32> {
-    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32>
     // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>)
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32>
     // CHECK: return [[valloc]] : memref<3x4xf32>
     return %0 : memref<3x4xf32>
@@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } {
   func.func @allreduce_new_type(
     // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32>
     %arg0 : memref<3x4xf32>) -> memref<3x4xf64> {
-    // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32
+    // CHECK: [[vc1_i32:%.*]] = arith.constant 1 : i32
     // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32
     // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64>
     // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>)
     // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm
-    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm
-    // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64>
-    // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64>
+    // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm
+    // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64>
     %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64>
     // CHECK: return [[valloc]] : memref<3x4xf64>
     return %0 : memref<3x4xf64>
   }
+
+  // CHECK-LABEL: func @allgather_tensor
+  func.func @allgather_tensor(
+      // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32>
+      // CHECK-SAME: -> tensor<3x20xf32>
+      %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> {
+    // CHECK: [[vc2_i32:%.*]]...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/177202


More information about the Mlir-commits mailing list