[Mlir-commits] [mlir] Lower allreduce (PR #144716)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 18 07:47:05 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

A "git pull" messed up the previous #<!-- -->144060. The difference from #<!-- -->144060 is that I applied separate conversion and rewrite patterns in the pass, and I renamed the reduction op as was requested in a nit.

@<!-- -->tkarna 

---

Patch is 67.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144716.diff


14 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+2) 
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.h (+1) 
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPI.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/MPI/IR/MPIOps.td (+7-5) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+5) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h (+5-5) 
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+5) 
- (modified) mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp (+31-31) 
- (modified) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+132-39) 
- (modified) mlir/lib/Dialect/MPI/IR/MPIOps.cpp (+35) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+23) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+16-6) 
- (modified) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+207-150) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..5a864865adffc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
     shard/partition sizes depend on the rank.
   }];
   let dependentDialects = [
+    "affine::AffineDialect",
+    "arith::ArithDialect",
     "memref::MemRefDialect",
     "mpi::MPIDialect",
     "scf::SCFDialect",
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
index f06b911ce3fe3..2b6743cd008c6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.h
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
 // MPIDialect
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
index f2837e71df060..0c62a1794e19e 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td
@@ -230,7 +230,7 @@ def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
 def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
 def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
 
-def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
+def MPI_ReductionOpEnum : I32EnumAttr<"MPI_ReductionOpEnum", "MPI operation class", [
       MPI_OpNull,
       MPI_OpMax,
       MPI_OpMin,
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d78aa92d201e7..935e0f785ef0c 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/MPI/IR/MPI.td"
 include "mlir/Dialect/MPI/IR/MPITypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 class MPI_Op<string mnemonic, list<Trait> traits = []>
     : Op<MPI_Dialect, mnemonic, traits>;
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
 // CommWorldOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
   let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
   let description = [{
     This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
 // CommRankOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
   let summary = "Get the current rank, equivalent to "
                 "`MPI_Comm_rank(comm, &rank)`";
   let description = [{
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
   );
 
   let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
 // CommSizeOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
   let summary = "Get the size of the group associated to the communicator, "
                 "equivalent to `MPI_Comm_size(comm, &size)`";
   let description = [{
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
 // CommSplitOp
 //===----------------------------------------------------------------------===//
 
-def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
   let summary = "Partition the group associated with the given communicator into "
                 "disjoint subgroups";
   let description = [{
@@ -281,7 +283,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
   let arguments = (
     ins AnyMemRef : $sendbuf,
     AnyMemRef : $recvbuf,
-    MPI_OpClassEnum : $op,
+    MPI_ReductionOpEnum : $op,
     MPI_Comm : $comm
   );
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 3878505f8f93f..c4d512b60bc51 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -212,6 +212,11 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
                                          OpOperand &operand,
                                          OpBuilder &builder);
 
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the
+/// provided type.
+SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+                                    llvm::ArrayRef<int64_t> statics,
+                                    ValueRange dynamics, Type type = Type());
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f59c4c4c67517..ac05ee243d7be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     ```
   }];
   let arguments = !con(commonArgs, (ins
-    AnyRankedTensor:$input,
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
     DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
   ));
   let results = (outs
-    AnyRankedTensor:$result
+    AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
   );
   let assemblyFormat = [{
     $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index c64da29ca6412..3f1041cb25103 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
   auto isEndomorphismOp = [reduction](Operation *op,
                                       std::optional<Operation *> referenceOp) {
     auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
-    if (!allReduceOp ||
-        allReduceOp.getInput().getType().getElementType() !=
-            allReduceOp.getResult().getType().getElementType() ||
+    auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
+    auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
+    if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
         allReduceOp.getReduction() != reduction) {
       return false;
     }
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
     }
 
     auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+    auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
     return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
-           allReduceOp.getInput().getType().getElementType() ==
-               refAllReduceOp.getInput().getType().getElementType();
+           inType.getElementType() == refType.getElementType();
   };
   auto isAlgebraicOp = [](Operation *op) {
     return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index be82e2af399dc..f46c0db846088 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -42,6 +42,11 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
                                                ArrayRef<MeshAxis> meshAxes,
                                                ImplicitLocOpBuilder &builder);
+// Get the process linear index from a multi-index along the given mesh axes.
+TypedValue<IndexType>
+createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
+                         ArrayRef<MeshAxis> meshAxes,
+                         ImplicitLocOpBuilder &builder);
 
 } // namespace mesh
 } // namespace mlir
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5575b295ae20a..d4deff5b88070 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -116,7 +116,7 @@ class MPIImplTraits {
   /// enum value.
   virtual Value getMPIOp(const Location loc,
                          ConversionPatternRewriter &rewriter,
-                         mpi::MPI_OpClassEnum opAttr) = 0;
+                         mpi::MPI_ReductionOpEnum opAttr) = 0;
 };
 
 //===----------------------------------------------------------------------===//
@@ -199,49 +199,49 @@ class MPICHImplTraits : public MPIImplTraits {
   }
 
   Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
-                 mpi::MPI_OpClassEnum opAttr) override {
+                 mpi::MPI_ReductionOpEnum opAttr) override {
     int32_t op = MPI_NO_OP;
     switch (opAttr) {
-    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+    case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
       op = MPI_NO_OP;
       break;
-    case mpi::MPI_OpClassEnum::MPI_MAX:
+    case mpi::MPI_ReductionOpEnum::MPI_MAX:
       op = MPI_MAX;
       break;
-    case mpi::MPI_OpClassEnum::MPI_MIN:
+    case mpi::MPI_ReductionOpEnum::MPI_MIN:
       op = MPI_MIN;
       break;
-    case mpi::MPI_OpClassEnum::MPI_SUM:
+    case mpi::MPI_ReductionOpEnum::MPI_SUM:
       op = MPI_SUM;
       break;
-    case mpi::MPI_OpClassEnum::MPI_PROD:
+    case mpi::MPI_ReductionOpEnum::MPI_PROD:
       op = MPI_PROD;
       break;
-    case mpi::MPI_OpClassEnum::MPI_LAND:
+    case mpi::MPI_ReductionOpEnum::MPI_LAND:
       op = MPI_LAND;
       break;
-    case mpi::MPI_OpClassEnum::MPI_BAND:
+    case mpi::MPI_ReductionOpEnum::MPI_BAND:
       op = MPI_BAND;
       break;
-    case mpi::MPI_OpClassEnum::MPI_LOR:
+    case mpi::MPI_ReductionOpEnum::MPI_LOR:
       op = MPI_LOR;
       break;
-    case mpi::MPI_OpClassEnum::MPI_BOR:
+    case mpi::MPI_ReductionOpEnum::MPI_BOR:
       op = MPI_BOR;
       break;
-    case mpi::MPI_OpClassEnum::MPI_LXOR:
+    case mpi::MPI_ReductionOpEnum::MPI_LXOR:
       op = MPI_LXOR;
       break;
-    case mpi::MPI_OpClassEnum::MPI_BXOR:
+    case mpi::MPI_ReductionOpEnum::MPI_BXOR:
       op = MPI_BXOR;
       break;
-    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+    case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
       op = MPI_MINLOC;
       break;
-    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+    case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
       op = MPI_MAXLOC;
       break;
-    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+    case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
       op = MPI_REPLACE;
       break;
     }
@@ -336,49 +336,49 @@ class OMPIImplTraits : public MPIImplTraits {
   }
 
   Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
-                 mpi::MPI_OpClassEnum opAttr) override {
+                 mpi::MPI_ReductionOpEnum opAttr) override {
     StringRef op;
     switch (opAttr) {
-    case mpi::MPI_OpClassEnum::MPI_OP_NULL:
+    case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
       op = "ompi_mpi_no_op";
       break;
-    case mpi::MPI_OpClassEnum::MPI_MAX:
+    case mpi::MPI_ReductionOpEnum::MPI_MAX:
       op = "ompi_mpi_max";
       break;
-    case mpi::MPI_OpClassEnum::MPI_MIN:
+    case mpi::MPI_ReductionOpEnum::MPI_MIN:
       op = "ompi_mpi_min";
       break;
-    case mpi::MPI_OpClassEnum::MPI_SUM:
+    case mpi::MPI_ReductionOpEnum::MPI_SUM:
       op = "ompi_mpi_sum";
       break;
-    case mpi::MPI_OpClassEnum::MPI_PROD:
+    case mpi::MPI_ReductionOpEnum::MPI_PROD:
       op = "ompi_mpi_prod";
       break;
-    case mpi::MPI_OpClassEnum::MPI_LAND:
+    case mpi::MPI_ReductionOpEnum::MPI_LAND:
       op = "ompi_mpi_land";
       break;
-    case mpi::MPI_OpClassEnum::MPI_BAND:
+    case mpi::MPI_ReductionOpEnum::MPI_BAND:
       op = "ompi_mpi_band";
       break;
-    case mpi::MPI_OpClassEnum::MPI_LOR:
+    case mpi::MPI_ReductionOpEnum::MPI_LOR:
       op = "ompi_mpi_lor";
       break;
-    case mpi::MPI_OpClassEnum::MPI_BOR:
+    case mpi::MPI_ReductionOpEnum::MPI_BOR:
       op = "ompi_mpi_bor";
       break;
-    case mpi::MPI_OpClassEnum::MPI_LXOR:
+    case mpi::MPI_ReductionOpEnum::MPI_LXOR:
       op = "ompi_mpi_lxor";
       break;
-    case mpi::MPI_OpClassEnum::MPI_BXOR:
+    case mpi::MPI_ReductionOpEnum::MPI_BXOR:
       op = "ompi_mpi_bxor";
       break;
-    case mpi::MPI_OpClassEnum::MPI_MINLOC:
+    case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
       op = "ompi_mpi_minloc";
       break;
-    case mpi::MPI_OpClassEnum::MPI_MAXLOC:
+    case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
       op = "ompi_mpi_maxloc";
       break;
-    case mpi::MPI_OpClassEnum::MPI_REPLACE:
+    case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
       op = "ompi_mpi_replace";
       break;
     }
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 823d4d644f586..aaf1d39d48438 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -12,9 +12,9 @@
 
 #include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +22,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp
 
 class ConvertProcessLinearIndexOp
     : public OpConversionPattern<ProcessLinearIndexOp> {
-  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
 
 public:
   using OpConversionPattern::OpConversionPattern;
 
-  // Constructor accepting worldRank
-  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
-                              MLIRContext *context, int64_t worldRank = -1)
-      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
-
   LogicalResult
   matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
+    // Create mpi::CommRankOp
     Location loc = op.getLoc();
-    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
-      return success();
-    }
-
-    // Otherwise call create mpi::CommRankOp
     auto ctx = op.getContext();
     Value commWorld =
         rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
@@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
   }
 };
 
+static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
+  auto ctx = kind.getContext();
+  switch (kind.getValue()) {
+  case ReductionKind::Sum:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_SUM);
+  case ReductionKind::Product:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
+  case ReductionKind::Min:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MIN);
+  case ReductionKind::Max:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MAX);
+  case ReductionKind::BitwiseAnd:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
+  case ReductionKind::BitwiseOr:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BOR);
+  case ReductionKind::BitwiseXor:
+    return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
+  default:
+    assert(false && "Unknown/unsupported reduction kind");
+  }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    auto mesh = adaptor.getMesh();
+    mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
+    if (!meshOp)
+      return op->emitError() << "No mesh found for AllReduceOp";
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return op->emitError()
+             << "Dynamic mesh shape not supported in AllReduceOp";
+
+    ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+    Value input = adaptor.getInput();
+    auto inputShape = cast<ShapedType>(input.getType()).getShape();
+
+    // If the source is a tensor, convert it to a memref.
+    if (isa<RankedTensorType>(input.getType())) {
+      auto memrefType = MemRefType::get(
+          inputShape, cast<ShapedType>(input.getType()).getElementType());
+      input = iBuilder.create<bufferization::ToBufferOp>(memrefType, input);
+    }
+    MemRefType inType = cast<MemRefType>(input.getType());
+
+    // Get the actual shape to allocate the buffer.
+    SmallVector<OpFoldResult> shape(inType.getRank());
+    for (auto i = 0; i < inType.getRank(); ++i) {
+      auto s = inputShape[i];
+      if (ShapedType::isDynamic(s))
+        shape[i] = iBuilder.create<memref::DimOp>(input, s).getResult();
+      else
+        shape[i] = iBuilder.getIndexAttr(s);
+    }
+
+    // Allocate buffer and copy input to buffer.
+    Value buffer = iBuilder.create<memref::AllocOp>(
+        shape, cast<ShapedType>(op.getType()).getElementType());
+    iBuilder.create<linalg::CopyOp>(input, buffer);
+
+    // Get an MPI_Comm_split for the AllReduce operation.
+    // The color is the linear index of the process in the mesh along the
+    // non-reduced axes. The key is the linear index of the process in the mesh
+    // along the reduced axes.
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       iBuilder.getIndexType());
+    SmallVector<Value> myMultiIndex =
+        iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+            .getResult();
+    Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+    SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+
+    auto redAxes = adaptor.getMeshAxes();
+    for (auto axis : redAxes) {
+      multiKey[axis] = myMultiIndex[axis];
+      myMultiIndex[axis] = zero;
+    }
+
+    Value color =
+        createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+    color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+    Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+    key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+
+    // Finally split the communicator
+    auto commType = mpi::CommType::get(op->getContext());
+    Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+    auto comm =
+        iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+            .getNewcomm();
+
+    Value buffer1d = buffer;
+    // Collapse shape to 1d if needed
+    if (inType.getRank() > 1) {
+      ReassociationIndices reassociation(inType.getRank());
+      std::iota(reassociation.begin(), reassociation.end(), 0);
+      buffer1d = iBuilder.create<memref::CollapseShapeOp>(
+          buffer, ArrayRef<ReassociationIndices>(reassociation));
+    }
+
+    // Create the MPI AllReduce operation.
+    iBuilder.create<mpi::AllReduceOp>(
+        TypeRange(), buffer1d, buffer1d,
+        getMPIReductionOp(adaptor.getReductionA...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144716


More information about the Mlir-commits mailing list