[Mlir-commits] [mlir] [mlir][shard, mpi] Adding Shard/MPI reduce_scatter (PR #184189)

llvmlistbot at llvm.org
Mon Mar 2 10:06:09 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

The partition pass often creates a pattern like
```
red = shard.all_reduce...grid_axes=x...
res = shard.all_slice...grid_axes=x...
```
which is effectively a `reduce-scatter` operation. A good communication library can implement `reduce-scatter` at a lower communication cost than an `allreduce` followed by a slice.
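The equivalence behind this rewrite can be checked with a minimal, single-process C++ sketch. It models `P` ranks as vectors and uses sum as the reduction on a 1-D buffer split into equal blocks; these are assumptions for illustration, and the function names are hypothetical. The sketch shows that "all-reduce, then each rank keeps its own block" produces the same per-rank data as a reduce-scatter, in which rank `r` only ever reduces block `r`. The cost difference follows from the same picture: with the common ring algorithm, for example, an all-reduce is itself a reduce-scatter followed by an all-gather, so the subsequent slice discards the all-gather's work.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One inner vector per simulated rank; all have nRanks * blockSize elements.
using Buffers = std::vector<std::vector<long>>;

// all_reduce(sum) on every rank, then each rank keeps only block r
// (the all_slice step).
Buffers allReduceThenSlice(const Buffers &in, std::size_t blockSize) {
  const std::size_t nRanks = in.size();
  std::vector<long> reduced(in[0].size(), 0);
  for (const auto &buf : in)
    for (std::size_t i = 0; i < buf.size(); ++i)
      reduced[i] += buf[i];
  Buffers out(nRanks);
  for (std::size_t r = 0; r < nRanks; ++r)
    out[r].assign(reduced.begin() + r * blockSize,
                  reduced.begin() + (r + 1) * blockSize);
  return out;
}

// reduce_scatter(sum): rank r only reduces (and receives) block r.
Buffers reduceScatter(const Buffers &in, std::size_t blockSize) {
  const std::size_t nRanks = in.size();
  Buffers out(nRanks, std::vector<long>(blockSize, 0));
  for (std::size_t r = 0; r < nRanks; ++r)
    for (const auto &buf : in)
      for (std::size_t i = 0; i < blockSize; ++i)
        out[r][i] += buf[r * blockSize + i];
  return out;
}
```

Calling both functions on the same per-rank inputs yields identical results, which is what allows the simplify pattern to replace the pair with a single `reduce-scatter`.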

To exploit this, this PR
- introduces a `shard-simplify` pass, which finds such patterns and replaces them with the equivalent `reduce-scatter`
- promotes the test pass `test-shard-optimizations` to a proper pass and adds the new pattern
- cleans up the `shard.reduce_scatter` op (renaming `scatter_axis` to `scatter_dim` and accepting memrefs as well as tensors)
- adds a new `mpi.reduce_scatter_block` op
- lowers `shard.reduce_scatter` to MPI
- lowers `mpi.reduce_scatter_block` to LLVM
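The matching condition the new pass uses (an `AllSliceOp` fed by an `AllReduceOp` on the same grid and `grid_axes`) can be sketched on a deliberately tiny op model. `Op`, `Kind` and `foldAllSliceOfAllReduce` are hypothetical names, not MLIR's C++ API, and carrying the reduction kind over to the new op is an assumption about what the real pattern does.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical, heavily simplified op model (not MLIR's C++ API).
enum class Kind { AllReduce, AllSlice, ReduceScatter };

struct Op {
  Kind kind;
  std::string grid;              // symbol name of the shard.grid
  std::vector<int64_t> gridAxes; // grid_axes attribute
  std::string reduction;         // reduction kind, e.g. "sum" (AllReduce only)
  int64_t dim = 0;               // slice dim / scatter_dim (unused by AllReduce)
  Op *input = nullptr;           // producer of the operand; nullptr = block arg
};

// Rewrites all_slice(all_reduce(x)) into reduce_scatter(x) in place when both
// ops use the same grid and grid_axes. Returns true if the op was rewritten.
bool foldAllSliceOfAllReduce(Op &slice) {
  if (slice.kind != Kind::AllSlice || slice.input == nullptr)
    return false;
  const Op &reduce = *slice.input;
  if (reduce.kind != Kind::AllReduce || reduce.grid != slice.grid ||
      reduce.gridAxes != slice.gridAxes)
    return false;
  slice.kind = Kind::ReduceScatter;   // the slice dim becomes scatter_dim
  slice.reduction = reduce.reduction; // keep the reduction kind (assumption)
  slice.input = reduce.input;         // consume the un-reduced value directly
  return true;
}
```

In this model the original all-reduce is simply left behind; in a real pattern-rewrite setting it would become dead and be erased if it has no other users.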

https://github.com/llvm/lighthouse/pull/58 works well with these changes.

---

Patch is 57.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184189.diff


23 Files Affected:

- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+34) 
- (modified) mlir/include/mlir/Dialect/Shard/IR/ShardOps.td (+10-10) 
- (modified) mlir/include/mlir/Dialect/Shard/Transforms/Partition.h (+1-1) 
- (modified) mlir/include/mlir/Dialect/Shard/Transforms/Passes.td (+17) 
- (renamed) mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h (+8-8) 
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+74-1) 
- (modified) mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp (+152-2) 
- (modified) mlir/lib/Dialect/Shard/IR/ShardOps.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt (+2-1) 
- (renamed) mlir/lib/Dialect/Shard/Transforms/Simplify.cpp (+81-27) 
- (modified) mlir/test/Conversion/MPIToLLVM/mpitollvm.mlir (+24) 
- (modified) mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir (+34) 
- (modified) mlir/test/Dialect/MPI/mpiops.mlir (+6) 
- (modified) mlir/test/Dialect/Shard/all-scatter-op-lowering.mlir (+6-6) 
- (modified) mlir/test/Dialect/Shard/canonicalization.mlir (+4-4) 
- (modified) mlir/test/Dialect/Shard/folding.mlir (+1-1) 
- (modified) mlir/test/Dialect/Shard/invalid.mlir (+10-10) 
- (modified) mlir/test/Dialect/Shard/ops.mlir (+10-10) 
- (renamed) mlir/test/Dialect/Shard/simplify.mlir (+84-1) 
- (modified) mlir/test/lib/Dialect/Shard/CMakeLists.txt (-1) 
- (modified) mlir/test/lib/Dialect/Shard/TestReshardingPartition.cpp (+1-1) 
- (removed) mlir/test/lib/Dialect/Shard/TestSimplifications.cpp (-47) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 7e68b152fdf75..0ed2a98545870 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -333,6 +333,40 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
                        "(`->` type($retval)^)?";
 }
 
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOp
+//===----------------------------------------------------------------------===//
+
+def MPI_ReduceScatterBlockOp : MPI_Op<"reduce_scatter_block", []> {
+  let summary = "Equivalent to `MPI_Reduce_scatter_block(sendbuf, recvbuf, "
+                "recvcount, dtype, op, comm)`";
+  let description = [{
+    MPI_Reduce_scatter_block first performs an element-wise reduction on the
+    sendbuf across all processes in the communicator, then scatters the result
+    by distributing equal-sized blocks to each process into recvbuf.
+
+    The `op` attribute specifies the reduction operation to be performed.
+    Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
+    supported.
+
+    This operation can optionally return an `!mpi.retval` value that can be used
+    to check for errors.
+  }];
+
+  let arguments = (
+    ins AnyNon0RankedMemRef : $sendbuf,
+    AnyNon0RankedMemRef : $recvbuf,
+    MPI_ReductionOpEnum : $op,
+    MPI_Comm : $comm
+  );
+
+  let results = (outs Optional<MPI_Retval>:$retval);
+
+  let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `,` $comm `)` "
+                       "attr-dict `:` type($sendbuf) `,` type($recvbuf) "
+                       "(`->` type($retval)^)?";
+}
+
 //===----------------------------------------------------------------------===//
 // BarrierOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
index 6ef7c72d305ee..60f6d6fc1ffe4 100644
--- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
+++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td
@@ -901,13 +901,13 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     `grid_axes` using the specified reduction method. The reduction is performed
     element-wise across the tensor pieces from all devices in the group.
     After reduction, the reduction result is scattered (split and distributed)
-    across the device group along `scatter_axis`.
+    across the device group along `scatter_dim`.
     Example:
     ```
     shard.grid @grid0(shape = 2x2)
     ...
     %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
-      reduction = <max> scatter_axis = 0
+      reduction = <max> scatter_dim = 0
       : tensor<2x2xf32> -> tensor<1x2xf64>
     ```
     Input:
@@ -940,17 +940,17 @@ def Shard_ReduceScatterOp : Shard_CollectiveCommunicationOpBase<"reduce_scatter"
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyNon0RankedTensor:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     DefaultValuedAttr<Shard_ReductionKindAttr, "::mlir::shard::ReductionKind::Sum">:$reduction,
-    IndexAttr:$scatter_axis
+    IndexAttr:$scatter_dim
   ));
   let results = (outs
-    AnyRankedTensor:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $grid (`grid_axes` `=` $grid_axes^)?
     (`reduction` `=` $reduction^)?
-    `scatter_axis` `=` $scatter_axis
+    `scatter_dim` `=` $scatter_dim
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
@@ -964,7 +964,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   let summary = "Scatter over a device grid.";
   let description = [{
     For each device group defined by `grid_axes`, the input tensor on the `root`
-    device is split along axis `scatter_axis` and distributed across the group.
+    device is split along axis `scatter_dim` and distributed across the group.
     The content of the input on all other (non-root) devices is ignored.
     The `root` device is defined by its in-group multi-index.
 
@@ -972,7 +972,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
     ```
     shard.grid @grid0(shape = 2x2)
     %1 = shard.scatter %0 on @grid0 grid_axes = [0]
-      scatter_axis = 0
+      scatter_dim = 0
       root = [1]
       : (tensor<2x2xi8>) -> tensor<1x2xi8>
     ```
@@ -1011,7 +1011,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   }];
   let arguments = !con(commonArgs, (ins
     AnyNon0RankedTensor:$input,
-    IndexAttr:$scatter_axis,
+    IndexAttr:$scatter_dim,
     DenseI64ArrayAttr:$root,
     Variadic<Index>:$root_dynamic
   ));
@@ -1020,7 +1020,7 @@ def Shard_ScatterOp : Shard_CollectiveCommunicationOpBase<"scatter", [
   );
   let assemblyFormat = [{
     $input `on` $grid (`grid_axes` `=` $grid_axes^)?
-    `scatter_axis` `=` $scatter_axis
+    `scatter_dim` `=` $scatter_dim
     `root` `=` custom<DynamicIndexList>($root_dynamic, $root)
     attr-dict `:` functional-type(operands, results)
   }];
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
index 37903765903db..ba12002b9f1eb 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Partition.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Partition.h - Shard Partition ----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
index bbc6a1977b13e..575c176217e61 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Passes.td
@@ -44,6 +44,23 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
   ];
 }
 
+def ShardSimplify : Pass<"shard-simplify"> {
+  let summary = "Shard simplify patterns.";
+  let description = [{
+    Applies simplification patterns on the Shard dialect operations.
+    This includes:
+    - All-reduce endomorphism simplification, e.g. transforming
+      `all_reduce_sum(x) + all_reduce_sum(y)` into `all_reduce_sum(x + y)`.
+    - Folding `AllSliceOp(AllReduceOp)` into `ReduceScatterOp` when both ops
+      share the same grid and grid_axes.
+    - Folding static grid shapes into constants.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "shard::ShardDialect"
+  ];
+}
+
 def Partition : InterfacePass<"shard-partition", "mlir::FunctionOpInterface"> {
   let summary = "Partition a function into SPMD form.";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
similarity index 89%
rename from mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
rename to mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
index 45ae758ec14c2..f3f4feffd8a71 100644
--- a/mlir/include/mlir/Dialect/Shard/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Shard/Transforms/Simplify.h
@@ -1,4 +1,4 @@
-//===- Simplifications.h - Shard Simplifications ----------------*- C++ -*-===//
+//===- Simplify.h - Shard Simplify ------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
-#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
+#define MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
 
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -37,8 +37,8 @@ namespace shard {
 // Will not work with some op `f(x, y, z)` where only `x` and `y` form
 // the algebraic structure.
 template <typename AlgebraicOp>
-void populateAllReduceEndomorphismSimplificationPatterns(
-    RewritePatternSet &patterns, ReductionKind reduction) {
+void populateAllReduceEndomorphismSimplifyPatterns(RewritePatternSet &patterns,
+                                                   ReductionKind reduction) {
   auto getEndomorphismOpOperand = [](Operation *op) {
     auto allReduceOp = llvm::cast<AllReduceOp>(op);
     return &allReduceOp.getInputMutable();
@@ -105,12 +105,12 @@ void populateAllReduceEndomorphismSimplificationPatterns(
 
 // It is invalid to change ops that declare symbols during the application of
 // these patterns, because symbolTableCollection is used to cache them.
-void populateSimplificationPatterns(
-    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void populateSimplifyPatterns(RewritePatternSet &patterns,
+                              SymbolTableCollection &symbolTableCollection);
 void populateFoldingPatterns(RewritePatternSet &patterns,
                              SymbolTableCollection &symbolTableCollection);
 
 } // namespace shard
 } // namespace mlir
 
-#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFICATIONS_H
+#endif // MLIR_DIALECT_SHARD_TRANSFORMS_SIMPLIFY_H
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 0dbc0a126a5c6..5087b96adf8ec 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -911,6 +912,78 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ReduceScatterBlockOpLowering
+//===----------------------------------------------------------------------===//
+
+struct ReduceScatterBlockOpLowering
+    : public ConvertOpToLLVMPattern<mpi::ReduceScatterBlockOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(mpi::ReduceScatterBlockOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = rewriter.getContext();
+    Type i32 = rewriter.getI32Type();
+    Type i64 = rewriter.getI64Type();
+    Type elemType = op.getSendbuf().getType().getElementType();
+    int64_t sRank = op.getSendbuf().getType().getRank();
+    int64_t rRank = op.getRecvbuf().getType().getRank();
+
+    // ptrType `!llvm.ptr`
+    Type ptrType = LLVM::LLVMPointerType::get(context);
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto mpiTraits = MPIImplTraits::get(moduleOp);
+    auto [sendPtr, sendSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), sRank, elemType);
+    auto [recvPtr, recvSize] =
+        getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), rRank, elemType);
+
+    // If input and output are the same, request in-place operation.
+    if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
+      sendPtr = LLVM::ConstantOp::create(
+          rewriter, loc, i64,
+          reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
+      sendPtr = LLVM::IntToPtrOp::create(rewriter, loc, ptrType, sendPtr);
+    }
+
+    Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
+    Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
+    Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
+
+    Value nRanks =
+        createOrFoldCommSize(rewriter, loc, op.getComm(), adaptor.getComm());
+    Value expected = LLVM::UDivOp::create(rewriter, loc, i32, sendSize, nRanks);
+    Value sizeIsValid = LLVM::ICmpOp::create(
+        rewriter, loc, LLVM::ICmpPredicate::eq, expected, recvSize);
+    cf::AssertOp::create(rewriter, loc, sizeIsValid,
+                         "Send buffer's size must be the receive buffer's size "
+                         "times the number of ranks");
+
+    // 'int MPI_Reduce_scatter_block(const void *sendbuf, void *recvbuf,
+    //     int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
+    auto funcType = LLVM::LLVMFunctionType::get(
+        i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
+              comm.getType()});
+    // get or create function declaration:
+    LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(
+        moduleOp, loc, rewriter, "MPI_Reduce_scatter_block", funcType);
+
+    // replace op with function call
+    auto funcCall = LLVM::CallOp::create(
+        rewriter, loc, funcDecl,
+        ValueRange{sendPtr, recvPtr, recvSize, dataType, mpiOp, comm});
+
+    if (op.getRetval())
+      rewriter.replaceOp(op, funcCall.getResult());
+    else
+      rewriter.eraseOp(op);
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertToLLVMPatternInterface implementation
 //===----------------------------------------------------------------------===//
@@ -943,7 +1016,7 @@ void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
   patterns.add<CommRankOpLowering, CommSizeOpLowering, CommSplitOpLowering,
                CommWorldOpLowering, FinalizeOpLowering, InitOpLowering,
                SendOpLowering, RecvOpLowering, AllGatherOpLowering,
-               AllReduceOpLowering>(converter);
+               AllReduceOpLowering, ReduceScatterBlockOpLowering>(converter);
 }
 
 void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 1db14e60c5a7f..68448fcbd4427 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -26,7 +26,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Shard/IR/ShardDialect.h"
 #include "mlir/Dialect/Shard/IR/ShardOps.h"
-#include "mlir/Dialect/Shard/Transforms/Simplifications.h"
+#include "mlir/Dialect/Shard/Transforms/Simplify.h"
 #include "mlir/Dialect/Shard/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -618,6 +618,156 @@ struct ConvertAllReduceOp : public CommOpPattern<AllReduceOp> {
   }
 };
 
+struct ConvertReduceScatterOp : public CommOpPattern<ReduceScatterOp> {
+  using CommOpPattern::CommOpPattern;
+
+  // shard.reduce_scatter reduces and then scatters along a specified
+  // scatter-dim. mpi.reduce_scatter_block always scatters along the first
+  // dimension. Hence, if scatter-dim != 0, we need to rearrange the input
+  // data by expanding the scatter-dim into {nRanks, output_scatter_dim} and
+  // transposing nRanks to the first dimension.
+
+  LogicalResult
+  matchAndRewrite(ReduceScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto gridAxes = adaptor.getGridAxes();
+    int64_t scatterDim = adaptor.getScatterDimAttr().getInt();
+
+    SymbolTableCollection symbolTableCollection;
+    FailureOr<GridOp> gridOp = checkGrid(op, symbolTableCollection);
+    if (failed(gridOp))
+      return failure();
+
+    ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+    Value rawInput = adaptor.getInput();
+    auto inShapedType = cast<ShapedType>(rawInput.getType());
+    auto elemType = inShapedType.getElementType();
+    MemRefType outType = getMemrefType(cast<ShapedType>(op.getType()));
+    auto inputShape = inShapedType.getShape();
+    auto outputShape = outType.getShape();
+    int64_t inputDimOnAxis = inputShape[scatterDim];
+    int64_t outputDimOnAxis = outputShape[scatterDim];
+
+    for (size_t i = 0; i < outputShape.size(); ++i)
+      if (outputShape[i] != inputShape[i] &&
+          i != static_cast<size_t>(scatterDim))
+        return op.emitError(
+            "Result and input shapes must match along non-scatter axes.");
+    if (outputDimOnAxis == 0)
+      return op.emitError(
+          "Output size along the scatter axis must be non-zero.");
+    if (inputDimOnAxis % outputDimOnAxis != 0)
+      return op.emitError(
+          "Input size along the scatter axis must be an exact "
+          "multiple of the output size along the scatter axis.");
+
+    if (!memref::isStaticShapeAndContiguousRowMajor(outType))
+      return op.emitError("Result must be a statically shaped memref in "
+                          "contiguous row-major layout.");
+
+    int64_t nRanks = inputDimOnAxis / outputDimOnAxis;
+
+    // Verify that nRanks matches the number of devices along the grid axes.
+    int64_t gridGroupSize =
+        collectiveProcessGroupSize(gridAxes, gridOp->getShape());
+    if (nRanks != gridGroupSize)
+      return op.emitError()
+             << "Expected the scatter factor (" << nRanks
+             << ") to match the number of devices along grid_axes ("
+             << gridGroupSize << ").";
+
+    // Get the right communicator.
+    Value comm = getComm(*gridOp, gridAxes, ib);
+
+    Value mpiInput;
+    if (scatterDim == 0) {
+      // scatter_dim == 0 maps directly to MPI_Reduce_scatter_block.
+      // Input must be contiguous for MPI.
+      Value input = getAsMemref(rawInput, ib);
+      MemRefType inType = cast<MemRefType>(input.getType());
+      if (!memref::isStaticShapeAndContiguousRowMajor(inType))
+        return op.emitError("Input must be a statically shaped memref in "
+                            "contiguous row-major layout.");
+      mpiInput = input;
+    } else {
+      // For scatter_dim != 0 we rearrange the input so the scatter factor
+      // becomes the first dimension.
+      //
+      // 1. Get a tensor representation of the input (avoid memref->tensor
+      //    round-trip if the input is already a tensor).
+      Value tensorInput = rawInput;
+      if (!isa<RankedTensorType>(rawInput.getType())) {
+        auto inTensorType = RankedTensorType::get(inputShape, elemType);
+        tensorInput =
+            bufferization::ToTensorOp::create(ib, inTensorType, rawInput, true);
+      }
+
+      // 2. Expand the scatter dim from {d0, ..., d_sd, ..., dN} to
+      //    {d0, ..., nRanks, o_sd, ..., dN}.
+      SmallVector<int64_t> expandedShape;
+      SmallVector<ReassociationIndices> expandReassociation;
+      int64_t expandedIdx = 0;
+      for (int64_t i = 0; i < static_cast<int64_t>(inputShape.size()); ++i) {
+        if (i == scatterDim) {
+          expandedShape.push_back(nRanks);
+          expandedShape.push_back(outputDimOnAxis);
+          expandReassociation.push_back({expandedIdx, expandedIdx + 1});
+          expandedIdx += 2;
+        } else {
+          expandedShape.push_back(inputShape[i]);
+          expandReassociation.push_back({expandedIdx});
+          expandedIdx += 1;
+        }
+      }
+      auto expandedType = RankedTensorType::get(expandedShape, elemType);
+      tensorInput = tensor::ExpandShapeOp::create(ib, expandedType, tensorInput,
+                                                  expandReassociation);
+
+      // 3. Transpose to move nRanks (at position scatterDim) to position 0:
+      //    {d0, ..., nRanks, o_sd, ..., dN} -> {nRanks, d0, ..., o_sd, ..., dN}
+      SmallVector<int64_t> permutation, transposedShape;
+      permutation.emplace_back(scatterDim);
+      for (int64_t i = 0; i < scatterDim; ++i)
+        permutation.emplace_back(i);
+      for (int64_t i = scatterDim + 1; i < (int64_t)expandedShape.size(); ++i)
+        permutation.emplace_back(i);
+      for (auto p : permutation)
+        transposedShape.emplace_back(expandedShape[p]);
+
+      Value permOutput = tensor::EmptyOp::create(ib, transposedShape, elemType);
+      tensorInput =
+          linalg::TransposeOp::create(ib, tensorInput, permOutput, permutation)
+              ->getResult(0);
+
+      // 4. ...
[truncated]

``````````
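Steps 2 and 3 of `ConvertReduceScatterOp` (expanding `scatter_dim` into `{nRanks, output_scatter_dim}` and transposing `nRanks` to the front) can be checked without MLIR. The following is a minimal standalone C++ sketch under stated assumptions: `ScatterLayout`, `planReduceScatter` and `transposeRowMajor` are hypothetical names, only the static-shape checks visible in the diff are reproduced (not the grid-group-size or memref-layout checks), and the buffer is assumed to be row-major. After the transpose, the `r`-th contiguous block of the buffer holds exactly the slice destined for rank `r`, which is the layout `MPI_Reduce_scatter_block` expects.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

struct ScatterLayout {
  int64_t nRanks = 0;               // scatter factor: in[sd] / out[sd]
  Shape expandedShape;              // scatter dim split into {nRanks, outDim}
  std::vector<int64_t> permutation; // moves the nRanks dim to position 0
  Shape transposedShape;            // expandedShape permuted by permutation
};

// Validates shapes like the lowering does and computes the expand/transpose
// plan. Returns std::nullopt for shapes the lowering would reject.
std::optional<ScatterLayout> planReduceScatter(const Shape &in, const Shape &out,
                                               int64_t scatterDim) {
  const auto rank = static_cast<int64_t>(in.size());
  if (static_cast<int64_t>(out.size()) != rank || scatterDim < 0 ||
      scatterDim >= rank)
    return std::nullopt;
  for (int64_t i = 0; i < rank; ++i)
    if (i != scatterDim && in[i] != out[i])
      return std::nullopt; // non-scatter dims must match
  const int64_t outDim = out[scatterDim];
  if (outDim == 0 || in[scatterDim] % outDim != 0)
    return std::nullopt; // input must be an exact multiple of the output

  ScatterLayout layout;
  layout.nRanks = in[scatterDim] / outDim;
  for (int64_t i = 0; i < rank; ++i) {
    if (i == scatterDim) {
      layout.expandedShape.push_back(layout.nRanks);
      layout.expandedShape.push_back(outDim);
    } else {
      layout.expandedShape.push_back(in[i]);
    }
  }
  layout.permutation.push_back(scatterDim);
  for (int64_t i = 0; i < scatterDim; ++i)
    layout.permutation.push_back(i);
  for (int64_t i = scatterDim + 1;
       i < static_cast<int64_t>(layout.expandedShape.size()); ++i)
    layout.permutation.push_back(i);
  for (int64_t p : layout.permutation)
    layout.transposedShape.push_back(layout.expandedShape[p]);
  return layout;
}

// Row-major transpose with linalg.transpose semantics:
// result dim k corresponds to input dim perm[k].
std::vector<int> transposeRowMajor(const std::vector<int> &data,
                                   const Shape &shape,
                                   const std::vector<int64_t> &perm) {
  const auto rank = static_cast<int64_t>(shape.size());
  Shape strides(shape.size(), 1);
  for (int64_t k = rank - 2; k >= 0; --k)
    strides[k] = strides[k + 1] * shape[k + 1];
  Shape outShape(shape.size()), outIdx(shape.size(), 0);
  for (int64_t k = 0; k < rank; ++k)
    outShape[k] = shape[perm[k]];
  std::vector<int> result(data.size());
  for (std::size_t linear = 0; linear < data.size(); ++linear) {
    int64_t src = 0;
    for (int64_t k = 0; k < rank; ++k)
      src += outIdx[k] * strides[perm[k]];
    result[linear] = data[static_cast<std::size_t>(src)];
    // Advance the output multi-index in row-major order.
    for (int64_t k = rank - 1; k >= 0; --k) {
      if (++outIdx[k] < outShape[k])
        break;
      outIdx[k] = 0;
    }
  }
  return result;
}
```

For an input of shape `2x6` scattered along dimension 1 into `2x3`, the plan is `nRanks = 2`, expanded shape `2x2x3` and permutation `[1, 0, 2]`; with input values `0..11`, the first block of the transposed buffer is `{0, 1, 2, 6, 7, 8}`, i.e. the left half of each row.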

</details>


https://github.com/llvm/llvm-project/pull/184189

