[Mlir-commits] [mlir] 79eb406 - [mlir][mesh, MPI] Mesh2mpi (#104566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 28 01:38:41 PST 2024
Author: Frank Schlimbach
Date: 2024-11-28T09:38:38Z
New Revision: 79eb406a67fe08458548289da72cda18248a9313
URL: https://github.com/llvm/llvm-project/commit/79eb406a67fe08458548289da72cda18248a9313
DIFF: https://github.com/llvm/llvm-project/commit/79eb406a67fe08458548289da72cda18248a9313.diff
LOG: [mlir][mesh, MPI] Mesh2mpi (#104566)
Pass for lowering `Mesh` to `MPI`.
Initial commit lowers `UpdateHaloOp` only.
Added:
mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
Modified:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Dialect/MPI/IR/MPIOps.cpp
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
mlir/test/Dialect/Mesh/ops.mlir
mlir/test/Dialect/Mesh/spmdization.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..44a1cc0adb6a0c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Creates a pass that lowers Mesh communication operations (currently
+/// `mesh.update_halo` and the process-index ops it relies on; later e.g.
+/// AllGather, ...) to MPI primitives.
+std::unique_ptr<::mlir::Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2ab32836c80b1c..b577aa83946f23 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 76ac386057ef2b..d722bd1f3e296a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -883,6 +883,29 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+  let summary = "Convert Mesh dialect to MPI dialect.";
+  let constructor = "mlir::createConvertMeshToMPIPass()";
+  let description = [{
+    This pass converts communication operations from the Mesh dialect to the
+    MPI dialect.
+    If it finds a global named "static_mpi_rank" it will use that splat value
+    instead of calling MPI_Comm_rank. This allows optimizations like constant
+    shape propagation and fusion because shard/partition sizes depend on the
+    rank.
+  }];
+  // Every dialect whose ops the patterns create must be loaded up front;
+  // the lowering emits arith ops (constants, index casts, cmpi, ...).
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect",
+    "mpi::MPIDialect",
+    "scf::SCFDialect",
+    "bufferization::BufferizationDialect"
+  ];
+}
+
//===----------------------------------------------------------------------===//
// NVVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index 768f376e24da4c..240fac5104c34f 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -84,6 +84,7 @@ def MPI_SendOp : MPI_Op<"send", []> {
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -114,6 +115,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> {
let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
+ let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 19498fe5a32d69..6039e61a93fadc 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -156,6 +156,40 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
];
}
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+      "For given mesh index get the linear indices of the direct neighbor processes along the given split.";
+  let description = [{
+    Example:
+    ```
+    mesh.mesh @mesh0(shape = 10x20x30)
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c2, %c3] split_axes = [1] : index, index
+    ```
+    The above returns two indices, `633` and `693`, which correspond to the
+    (row-major) linear index of the previous process `(1, 1, 3)`, and the next
+    process `(1, 3, 3)` along the split axis `1`.
+
+    A negative value is returned if there is no neighbor in the respective
+    direction along the given `split_axes`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh,
+                       Variadic<Index>:$device,
+                       Mesh_MeshAxesAttr:$split_axes);
+  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+  let assemblyFormat = [{
+      `on` $mesh `[` $device `]`
+      `split_axes` `=` $split_axes
+      attr-dict `:` type(results)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// Sharding operations.
//===----------------------------------------------------------------------===//
@@ -1058,12 +1092,12 @@ def Mesh_ShiftOp : Mesh_CollectiveCommunicationOpBase<"shift", [
}
def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
+ Pure,
DestinationStyleOpInterface,
TypesMatchWith<
"result has same type as destination",
"result", "destination", "$_self">,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>,
- AttrSizedOperandSegments
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Update halo data.";
let description = [{
@@ -1072,7 +1106,7 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
on the remote devices. Changes might be caused by mutating operations
and/or if the new halo regions are larger than the existing ones.
-  Source and destination might have different halo sizes.
+ Destination is supposed to be initialized with the local data (not halos).
Assumes all devices hold tensors with same-sized halo data as specified
by `source_halo_sizes/static_source_halo_sizes` and
@@ -1084,25 +1118,21 @@ def Mesh_UpdateHaloOp : Mesh_Op<"update_halo", [
}];
let arguments = (ins
- AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$source,
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$destination,
FlatSymbolRefAttr:$mesh,
Mesh_MeshAxesArrayAttr:$split_axes,
- Variadic<I64>:$source_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_source_halo_sizes,
- Variadic<I64>:$destination_halo_sizes,
- DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_destination_halo_sizes
+ Variadic<I64>:$halo_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_halo_sizes
);
let results = (outs
AnyTypeOf<[AnyNon0RankedMemRef, AnyNon0RankedTensor]>:$result
);
let assemblyFormat = [{
- $source `into` $destination
+ $destination
`on` $mesh
`split_axes` `=` $split_axes
- (`source_halo_sizes` `=` custom<DynamicIndexList>($source_halo_sizes, $static_source_halo_sizes)^)?
- (`destination_halo_sizes` `=` custom<DynamicIndexList>($destination_halo_sizes, $static_destination_halo_sizes)^)?
- attr-dict `:` type($source) `->` type($result)
+ (`halo_sizes` `=` custom<DynamicIndexList>($halo_sizes, $static_halo_sizes)^)?
+ attr-dict `:` type($result)
}];
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6651d87162257f..62461c0cea08af 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+  MeshToMPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # MeshToMPI.cpp directly builds arith, bufferization, scf, memref, mesh and
+  # mpi ops and uses getMixedValues (DialectUtils); link them explicitly
+  # instead of relying on transitive dependencies.
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRBufferizationDialect
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRefDialect
+  MLIRMeshDialect
+  MLIRMPIDialect
+  MLIRPass
+  MLIRSCFDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..6dd89ecf4d5c2d
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,440 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+// Instantiate the TableGen-generated pass base class
+// (impl::ConvertMeshToMPIPassBase), which ConvertMeshToMPIPass derives from.
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+// Emit ops that turn a linear index into a multi-dimensional index over
+// `dimensions`. The last dimension varies fastest (row-major): coordinate i is
+// `remaining % dimensions[i]`, after which `remaining` is divided by that
+// extent for the next (slower) dimension.
+static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
+                                             Value linearIndex,
+                                             ValueRange dimensions) {
+  int n = dimensions.size();
+  SmallVector<Value> multiIndex(n);
+  Value remaining = linearIndex;
+
+  int i = n - 1;
+  while (i >= 0) {
+    Value extent = dimensions[i];
+    multiIndex[i] = b.create<arith::RemSIOp>(loc, remaining, extent);
+    // The leading coordinate needs no further division.
+    if (i != 0)
+      remaining = b.create<arith::DivSIOp>(loc, remaining, extent);
+    --i;
+  }
+
+  return multiIndex;
+}
+
+// Emit ops that turn a multi-dimensional index into a linear index over
+// `dimensions` (row-major, i.e. the inverse of linearToMultiIndex): the
+// stride starts at 1 for the last dimension and is multiplied by each extent
+// while walking towards the first one.
+Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
+                         ValueRange dimensions) {
+  Value result = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+
+  int i = multiIndex.size();
+  while (--i >= 0) {
+    Value scaled = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    result = b.create<arith::AddIOp>(loc, result, scaled);
+    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
+  }
+
+  return result;
+}
+
+struct ConvertProcessMultiIndexOp
+    : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Lowers `mesh.process_multi_index` by delinearizing the process' linear
+  /// index (a new `mesh.process_linear_index`) over the static mesh shape.
+  /// When the op lists axes, only those coordinates are returned, in the
+  /// listed order; otherwise all coordinates are returned.
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    SymbolTableCollection symbolTableCollection;
+    Location loc = op.getLoc();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto meshShape = meshOp.getShape();
+
+    // Dynamic mesh shapes cannot be materialized as constants (yet).
+    if (ShapedType::isDynamicShape(meshShape))
+      return mlir::failure();
+
+    SmallVector<Value> dims;
+    for (int64_t extent : meshShape)
+      dims.push_back(
+          rewriter.create<arith::ConstantIndexOp>(loc, extent).getResult());
+
+    Value linearIdx =
+        rewriter.create<ProcessLinearIndexOp>(loc, meshOp).getResult();
+    SmallVector<Value> multiIdx =
+        linearToMultiIndex(loc, rewriter, linearIdx, dims);
+
+    auto axes = op.getAxes();
+    if (axes.empty()) {
+      rewriter.replaceOp(op, multiIdx);
+      return mlir::success();
+    }
+
+    SmallVector<Value> selected;
+    for (auto axis : axes)
+      selected.push_back(multiIdx[axis]);
+    rewriter.replaceOp(op, selected);
+    return mlir::success();
+  }
+};
+
+struct ConvertProcessLinearIndexOp
+    : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Lowers `mesh.process_linear_index` to the MPI rank.
+  /// If a memref.global named "static_mpi_rank" with a splat integer
+  /// initializer is visible from `op`, its value is used as a constant
+  /// index. Otherwise the rank is queried with mpi.comm_rank and cast to
+  /// index.
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
+    if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
+            op, rankOpName)) {
+      // Uninitialized globals (UnitAttr), non-integer or non-splat
+      // initializers cannot provide a static rank; fall back to MPI below
+      // instead of asserting.
+      auto initTnsr = dyn_cast_or_null<DenseIntElementsAttr>(
+          globalOp.getInitialValueAttr());
+      if (initTnsr && initTnsr.isSplat()) {
+        // Read through APInt so any integer/index element width works
+        // (getSplatValue<int64_t> only accepts 64-bit storage).
+        int64_t val = initTnsr.getSplatValue<APInt>().getSExtValue();
+        rewriter.replaceOp(op,
+                           rewriter.create<arith::ConstantIndexOp>(loc, val));
+        return mlir::success();
+      }
+    }
+    auto rank =
+        rewriter
+            .create<mpi::CommRankOp>(
+                loc, TypeRange{mpi::RetvalType::get(op->getContext()),
+                               rewriter.getI32Type()})
+            .getRank();
+    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
+                                                    rank);
+    return mlir::success();
+  }
+};
+
+struct ConvertNeighborsLinearIndicesOp
+    : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Lowers `mesh.neighbors_linear_indices` for a single split axis.
+  /// The neighbors' multi-indices are the device index with the split-axis
+  /// coordinate decremented (down) / incremented (up), linearized row-major
+  /// over the mesh shape. A neighbor outside the mesh yields -1.
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto axes = op.getSplitAxes();
+    // For now only single axis sharding is supported
+    if (axes.size() != 1) {
+      return mlir::failure();
+    }
+
+    auto loc = op.getLoc();
+    SymbolTableCollection symbolTableCollection;
+    auto meshOp = getMesh(op, symbolTableCollection);
+    // The mesh shape is materialized as index constants below, so (as in
+    // ConvertProcessMultiIndexOp) only static mesh shapes are supported.
+    if (ShapedType::isDynamicShape(meshOp.getShape())) {
+      return mlir::failure();
+    }
+
+    auto mIdx = op.getDevice();
+    auto orgIdx = mIdx[axes[0]];
+    SmallVector<Value> dims;
+    llvm::transform(
+        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
+          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
+        });
+    auto dimSz = dims[axes[0]];
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
+    auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
+
+    // Down neighbor: -1 at the lower border, else coordinate - 1.
+    auto atBorder = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sle, orgIdx,
+        rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+    auto down = rewriter.create<scf::IfOp>(
+        loc, atBorder,
+        [&](OpBuilder &builder, Location loc) {
+          builder.create<scf::YieldOp>(loc, minus1);
+        },
+        [&](OpBuilder &builder, Location loc) {
+          SmallVector<Value> tmp = mIdx;
+          tmp[axes[0]] =
+              builder.create<arith::SubIOp>(loc, orgIdx, one).getResult();
+          builder.create<scf::YieldOp>(
+              loc, multiToLinearIndex(loc, builder, tmp, dims));
+        });
+
+    // Up neighbor: -1 at the upper border, else coordinate + 1.
+    atBorder = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, orgIdx,
+        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
+    auto up = rewriter.create<scf::IfOp>(
+        loc, atBorder,
+        [&](OpBuilder &builder, Location loc) {
+          builder.create<scf::YieldOp>(loc, minus1);
+        },
+        [&](OpBuilder &builder, Location loc) {
+          SmallVector<Value> tmp = mIdx;
+          tmp[axes[0]] =
+              builder.create<arith::AddIOp>(loc, orgIdx, one).getResult();
+          builder.create<scf::YieldOp>(
+              loc, multiToLinearIndex(loc, builder, tmp, dims));
+        });
+
+    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
+    return mlir::success();
+  }
+};
+
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+
+    // The input/output memref is assumed to be in C memory order.
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). For each haloed dimension `d`, the exchanged blocks are
+    // expressed as multi-dimensional subviews. The subviews include potential
+    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
+    // `dl < d` and for dimension `d` the currently exchanged halo only.
+    // By iterating from higher to lower dimensions this also updates the halos
+    // in the 'corners'.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. Because subviews and halos can have mixed dynamic and static
+    // shapes, OpFoldResults are used whenever possible.
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // Convert an OpFoldResult into a Value (attributes become index
+    // constants).
+    auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+      return v.is<Value>()
+                 ? v.get<Value>()
+                 : rewriter.create<::mlir::arith::ConstantOp>(
+                       loc,
+                       rewriter.getIndexAttr(
+                           cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+    };
+
+    auto dest = op.getDestination();
+    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
+    Value array = dest;
+    if (isa<RankedTensorType>(array.getType())) {
+      // If the destination is a tensor, access it through a memref so the
+      // halos can be written in place.
+      auto memrefType = MemRefType::get(
+          dstShape, cast<ShapedType>(array.getType()).getElementType());
+      array =
+          rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array)
+              .getResult();
+    }
+    auto rank = cast<ShapedType>(array.getType()).getRank();
+    auto opSplitAxes = op.getSplitAxes().getAxes();
+    auto mesh = op.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto haloSizes =
+        getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
+    // subviews need Index values
+    for (auto &sz : haloSizes) {
+      if (sz.is<Value>()) {
+        sz = rewriter
+                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                             sz.get<Value>())
+                 .getResult();
+      }
+    }
+
+    // most of the offset/size/stride data is the same for all dims
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
+    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
+    // we need the actual shape to compute offsets and sizes
+    for (auto i = 0; i < rank; ++i) {
+      auto s = dstShape[i];
+      if (ShapedType::isDynamic(s)) {
+        // memref.dim expects the dimension *index*, not the (dynamic) extent.
+        shape[i] = rewriter
+                       .create<memref::DimOp>(loc, array,
+                                              static_cast<int64_t>(i))
+                       .getResult();
+      } else {
+        shape[i] = rewriter.getIndexAttr(s);
+      }
+
+      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+        ++currHaloDim;
+        // the offsets for lower dims start after their down halo
+        offsets[i] = haloSizes[currHaloDim * 2];
+
+        // prepare shape and offsets of highest dim's halo exchange
+        auto haloSz =
+            rewriter
+                .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                       toValue(haloSizes[currHaloDim * 2 + 1]))
+                .getResult();
+        // the halo shape of lower dims excludes the halos
+        dimSizes[i] =
+            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), haloSz)
+                .getResult();
+      } else {
+        dimSizes[i] = shape[i];
+      }
+    }
+
+    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       rewriter.getIndexType());
+    auto myMultiIndex =
+        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+            .getResult();
+    // traverse all split axes from high to low dim
+    // (int64_t instead of the non-standard ssize_t)
+    for (int64_t dim = static_cast<int64_t>(opSplitAxes.size()) - 1; dim >= 0;
+         --dim) {
+      auto splitAxes = opSplitAxes[dim];
+      if (splitAxes.empty()) {
+        continue;
+      }
+      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
+      // Get the linearized ids of the neighbors (down and up) for the
+      // given split
+      auto tmp = rewriter
+                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+                                                       splitAxes)
+                     .getResults();
+      // MPI operates on i32...
+      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[0]),
+                               rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), tmp[1])};
+
+      auto lowerRecvOffset = rewriter.getIndexAttr(0);
+      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
+      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
+          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+      auto upperSendOffset = rewriter.create<arith::SubIOp>(
+          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
+
+      // Make sure we send/recv in a way that does not lead to a dead-lock.
+      // The current approach is by far not optimal, this should be at least
+      // be a red-black pattern or using MPI_sendrecv.
+      // Also, buffers should be re-used.
+      // Still using temporary contiguous buffers for MPI communication...
+      // Still yielding a "serialized" communication pattern...
+      auto genSendRecv = [&](bool upperHalo) {
+        auto orgOffset = offsets[dim];
+        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                                  : haloSizes[currHaloDim * 2];
+        // Check if we need to send and/or receive.
+        // Processes on the mesh borders have only one neighbor.
+        // A process' upper halo holds the lowest owned data of its upper
+        // neighbor: for the upper halo we send our lowest owned data down
+        // (neighbourIDs[0]) and receive from up (neighbourIDs[1]). The lower
+        // halo is the mirror image.
+        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+        auto hasFrom = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, from, zero);
+        auto hasTo = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, to, zero);
+        auto buffer = rewriter.create<memref::AllocOp>(
+            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
+        // if has neighbor: copy halo data from array to buffer and send
+        rewriter.create<scf::IfOp>(
+            loc, hasTo, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
+                                       : OpFoldResult(upperSendOffset);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, dimSizes, strides);
+              builder.create<memref::CopyOp>(loc, subview, buffer);
+              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+              builder.create<scf::YieldOp>(loc);
+            });
+        // if has neighbor: receive halo data into buffer and copy to array
+        rewriter.create<scf::IfOp>(
+            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
+                                       : OpFoldResult(lowerRecvOffset);
+              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, dimSizes, strides);
+              builder.create<memref::CopyOp>(loc, buffer, subview);
+              builder.create<scf::YieldOp>(loc);
+            });
+        rewriter.create<memref::DeallocOp>(loc, buffer);
+        offsets[dim] = orgOffset;
+      };
+
+      genSendRecv(false);
+      genSendRecv(true);
+
+      // the shape for lower dims include higher dims' halos
+      dimSizes[dim] = shape[dim];
+      // -> the offset for higher dims is always 0
+      offsets[dim] = rewriter.getIndexAttr(0);
+      // on to next halo
+      --currHaloDim;
+    }
+
+    if (isa<MemRefType>(op.getResult().getType())) {
+      rewriter.replaceOp(op, array);
+    } else {
+      assert(isa<RankedTensorType>(op.getResult().getType()));
+      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
+                                 loc, op.getResult().getType(), array,
+                                 /*restrict=*/true, /*writable=*/true));
+    }
+    return mlir::success();
+  }
+};
+
+struct ConvertMeshToMPIPass
+    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+  using Base::Base;
+
+  /// Greedily applies the Mesh-to-MPI rewrite patterns (update_halo,
+  /// neighbors_linear_indices, process_linear_index, process_multi_index)
+  /// to the pass' root operation.
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+                    ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
+        context);
+
+    // The driver's result is ignored: non-convergence is not reported as a
+    // pass failure.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+// Creates a pass that converts Mesh ops to MPI ops.
+std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() {
+  std::unique_ptr<::mlir::Pass> pass = std::make_unique<ConvertMeshToMPIPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index ddd77b8f586ee0..dcb55d8921364f 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -7,12 +7,52 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::mpi;
+namespace {
+
+// If the op's memref operand has a dynamic shape and is produced by a
+// memref.cast whose source has a static shape, use the cast's static source
+// as the operand instead.
+template <typename OpT>
+struct FoldCast final : public mlir::OpRewritePattern<OpT> {
+  using mlir::OpRewritePattern<OpT>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpT op,
+                                mlir::PatternRewriter &b) const override {
+    auto mRef = op.getRef();
+    if (mRef.getType().hasStaticShape()) {
+      return mlir::failure();
+    }
+    auto castOp =
+        mlir::dyn_cast_or_null<mlir::memref::CastOp>(mRef.getDefiningOp());
+    if (!castOp) {
+      return mlir::failure();
+    }
+    auto src = castOp.getSource();
+    if (!src.getType().hasStaticShape()) {
+      return mlir::failure();
+    }
+    // In-place operand updates must be announced to the rewriter so that the
+    // pattern driver (e.g. the greedy canonicalizer) sees the modification.
+    b.modifyOpInPlace(op, [&]() { op.getRefMutable().assign(src); });
+    return mlir::success();
+  }
+};
+} // namespace
+
+/// mpi.send canonicalization: fold a dynamic-shaped memref.cast of a static
+/// memref into the operand.
+void mlir::mpi::SendOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx) {
+  patterns.add<FoldCast<SendOp>>(ctx);
+}
+
+/// mpi.recv canonicalization: same cast folding as for mpi.send.
+void mlir::mpi::RecvOp::getCanonicalizationPatterns(
+    mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx) {
+  patterns.add<FoldCast<RecvOp>>(ctx);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c5570d8ee8a443..33460ff25e9e45 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -837,6 +837,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
setNameFn(getResult(), "proc_linear_idx");
}
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+/// The only symbol this op references is its mesh; verify that it resolves
+/// (via getMeshAndVerify).
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return failed(::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable))
+             ? failure()
+             : success();
+}
+
+/// Default SSA names for the two results in printed IR.
+void NeighborsLinearIndicesOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value downIdx = getNeighborDown();
+  Value upIdx = getNeighborUp();
+  setNameFn(downIdx, "down_linear_idx");
+  setNameFn(upIdx, "up_linear_idx");
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b4d088cbd7088d..327ea0991e4e1e 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -496,11 +496,9 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
sourceShard.getLoc(),
RankedTensorType::get(outShape,
sourceShard.getType().getElementType()),
- sourceShard, initOprnd, mesh.getSymName(),
+ initOprnd, mesh.getSymName(),
MeshAxesArrayAttr::get(builder.getContext(),
sourceSharding.getSplitAxes()),
- sourceSharding.getDynamicHaloSizes(),
- sourceSharding.getStaticHaloSizes(),
targetSharding.getDynamicHaloSizes(),
targetSharding.getStaticHaloSizes())
.getResult();
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..25d585a108c8ae
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,208 @@
+// RUN: mlir-opt %s -convert-mesh-to-mpi -canonicalize -split-input-file | FileCheck %s
+
+// -----
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 3x4x5) // 60 processes; expected values below match row-major i0*20 + i1*5 + i2
+func.func @process_multi_index() -> (index, index, index) { // no @static_mpi_rank in this split: the multi-index is computed from mpi.comm_rank with arith.remsi
+ // CHECK: mpi.comm_rank : !mpi.retval, i32
+ // CHECK-DAG: %[[v4:.*]] = arith.remsi
+ // CHECK-DAG: %[[v0:.*]] = arith.remsi
+ // CHECK-DAG: %[[v1:.*]] = arith.remsi
+ %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[v1]], %[[v0]], %[[v4]] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index { // the linear index is the i32 MPI rank cast to index
+ // CHECK: %[[RES:.*]], %[[rank:.*]] = mpi.comm_rank : !mpi.retval, i32
+ // CHECK: %[[cast:.*]] = arith.index_cast %[[rank]] : i32 to index
+ %0 = mesh.process_linear_index on @mesh0 : index
+ // CHECK: return %[[cast]] : index
+ return %0 : index
+}
+
+// CHECK-LABEL: func @neighbors_dim0
+func.func @neighbors_dim0(%arg0 : tensor<120x120x120xi8>) -> (index, index) { // neighbors of the process at [1, 0, 4] (linear 24) along mesh axis 0
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ // CHECK-DAG: [[up:%.*]] = arith.constant 44 : index
+ // CHECK-DAG: [[down:%.*]] = arith.constant 4 : index
+ %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [0] : index, index // down = [0, 0, 4] -> 4, up = [2, 0, 4] -> 44
+ // CHECK: return [[down]], [[up]] : index, index
+ return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @neighbors_dim1
+func.func @neighbors_dim1(%arg0 : tensor<120x120x120xi8>) -> (index, index) { // along mesh axis 1 the process sits at index 0, so it has no lower neighbor
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ // CHECK-DAG: [[up:%.*]] = arith.constant 29 : index
+ // CHECK-DAG: [[down:%.*]] = arith.constant -1 : index
+ %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [1] : index, index // down = [1, -1, 4] lies off the mesh -> -1, up = [1, 1, 4] -> 29
+ // CHECK: return [[down]], [[up]] : index, index
+ return %idx#0, %idx#1 : index, index
+}
+
+// CHECK-LABEL: func @neighbors_dim2
+func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) { // along mesh axis 2 (size 5) the process sits at the last index 4, so it has no upper neighbor
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ // CHECK-DAG: [[up:%.*]] = arith.constant -1 : index
+ // CHECK-DAG: [[down:%.*]] = arith.constant 23 : index
+ %idx:2 = mesh.neighbors_linear_indices on @mesh0[%c1, %c0, %c4] split_axes = [2] : index, index // down = [1, 0, 3] -> 23, up = [1, 0, 5] lies off the mesh -> -1
+ // CHECK: return [[down]], [[up]] : index, index
+ return %idx#0, %idx#1 : index, index
+}
+
+// -----
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 3x4x5)
+memref.global constant @static_mpi_rank : memref<index> = dense<24> // with this global present, the expected results below are constants
+func.func @process_multi_index() -> (index, index, index) { // rank 24 on a 3x4x5 mesh is multi-index [1, 0, 4]
+ // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_linear_index
+func.func @process_linear_index() -> index { // folds to the value stored in @static_mpi_rank
+ // CHECK: %[[c24:.*]] = arith.constant 24 : index
+ %0 = mesh.process_linear_index on @mesh0 : index
+ // CHECK: return %[[c24]] : index
+ return %0 : index
+}
+
+// -----
+mesh.mesh @mesh0(shape = 3x4x5) // no @static_mpi_rank in this split; send/recv ranks are not matched, only buffer shapes
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first( // halo_sizes = [2, 3] on dim 0: 2-row lower halo at offset 0, 3-row upper halo at offset 117
+ // CHECK-SAME: [[arg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { // the result must be %arg0 itself (in-place update)
+ // CHECK: memref.subview [[arg0]][115, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK: mpi.send(
+ // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
+ // CHECK: mpi.recv(
+ // CHECK-SAME: : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.subview [[arg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK: memref.subview [[arg0]][2, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+ // CHECK: mpi.send(
+ // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
+ // CHECK: mpi.recv(
+ // CHECK-SAME: : memref<3x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.subview [[arg0]][117, 0, 0] [3, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<3x120x120xi8
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 3] : memref<120x120x120xi8>
+ // CHECK: return [[arg0]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+}
+
+// -----
+mesh.mesh @mesh0(shape = 3x4x5)
+memref.global constant @static_mpi_rank : memref<index> = dense<24> // process [1, 0, 4]; all neighbor ranks below fold to constants
+// CHECK-LABEL: func @update_halo_3d
+func.func @update_halo_3d( // NOTE(review): tensor dim 2 is split on mesh axis 0, where (per @neighbors_dim0) rank 24 has down neighbor 4 and up neighbor 44; below, the lower dim-2 halo at [1, 3, 0] is received from 44 and the interior slice at offset 109 is sent to 4. Conventionally the lower halo comes from the down neighbor -- confirm the intended direction.
+ // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> { // the result is the input memref, updated in place
+ // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+ // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
+ // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+ // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+ // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+ // CHECK: return [[varg0]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+}
+
+// CHECK-LABEL: func @update_halo_3d_tensor
+func.func @update_halo_3d_tensor( // tensor variant: converted with bufferization.to_memref, exchanged as in the memref test, then wrapped back with bufferization.to_tensor
+ // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
+ %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> { // NOTE(review): as in the memref case, the lower dim-2 halo is received from rank 44 (the up neighbor on mesh axis 0) -- confirm the direction
+ // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+ // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : memref<120x120x120xi8>
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+ // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
+ // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+ // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
+ // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+ // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+ // CHECK: return [[v1]] : tensor<120x120x120xi8>
+ return %res : tensor<120x120x120xi8>
+}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index d8df01c3d6520d..978de4939ee77c 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -615,16 +615,16 @@ func.func @update_halo(
// CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
- // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] into %[[ARG]] on @mesh0
+ // CHECK-NEXT: %[[UH1:.*]] = mesh.update_halo %[[ARG]] on @mesh0
// CHECK-SAME: split_axes = {{\[\[}}0]]
- // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8> -> memref<12x12xi8>
+ // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
%c2 = arith.constant 2 : i64
- %uh1 = mesh.update_halo %arg0 into %arg0 on @mesh0 split_axes = [[0]]
- source_halo_sizes = [2, %c2] : memref<12x12xi8> -> memref<12x12xi8>
- // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[ARG]] into %[[UH1]] on @mesh0
+ %uh1 = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+ halo_sizes = [2, %c2] : memref<12x12xi8>
+ // CHECK-NEXT: %[[UH2:.*]] = mesh.update_halo %[[UH1]] on @mesh0
// CHECK-SAME: split_axes = {{\[\[}}0], [1]]
- // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8> -> memref<12x12xi8>
- %uh2 = mesh.update_halo %arg0 into %uh1 on @mesh0 split_axes = [[0], [1]]
- source_halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8> -> memref<12x12xi8>
+ // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2] : memref<12x12xi8>
+ %uh2 = mesh.update_halo %uh1 on @mesh0 split_axes = [[0], [1]]
+ halo_sizes = [2, 2, %c2, 2] : memref<12x12xi8>
return
}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 22ddb72569835d..c1b96fda0f4a74 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -226,7 +226,7 @@ func.func @test_shard_update_halo(%arg0: tensor<1200x1200xi64>) -> tensor<1200x1
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] : !mesh.sharding
// CHECK: %[[T:.*]] = tensor.empty() : tensor<304x1200xi64>
// CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][2, 0] [300, 1200] [1, 1] : tensor<300x1200xi64> into tensor<304x1200xi64>
- // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} destination_halo_sizes = [2, 2] : tensor<300x1200xi64> -> tensor<304x1200xi64>
+ // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh_1d_4 split_axes = {{\[\[0]]}} halo_sizes = [2, 2] : tensor<304x1200xi64>
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
%sharding_0 = mesh.sharding @mesh_1d_4 split_axes = [[0]] halo_sizes = [2, 2] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
@@ -242,7 +242,7 @@ func.func @test_shard_update_halo2d(%arg0: tensor<1200x1200xi64>) -> tensor<1200
%sharding = mesh.sharding @mesh4x4 split_axes = [[0], [1]] : !mesh.sharding
// CHECK: %[[T:.*]] = tensor.empty() : tensor<303x307xi64>
// CHECK: %[[inserted_slice:.*]] = tensor.insert_slice %[[IN1]] into %[[T]][1, 3] [300, 300] [1, 1] : tensor<300x300xi64> into tensor<303x307xi64>
- // CHECK: %[[UH:.*]] = mesh.update_halo %[[IN1]] into %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] destination_halo_sizes = [1, 2, 3, 4] : tensor<300x300xi64> -> tensor<303x307xi64>
+ // CHECK: %[[UH:.*]] = mesh.update_halo %[[inserted_slice]] on @mesh4x4 split_axes = {{\[\[}}0], [1]] halo_sizes = [1, 2, 3, 4] : tensor<303x307xi64>
%sharding_annotated = mesh.shard %arg0 to %sharding : tensor<1200x1200xi64>
%sharding_0 = mesh.sharding @mesh4x4 split_axes = [[0], [1]] halo_sizes = [1, 2, 3, 4] : !mesh.sharding
%sharding_annotated_1 = mesh.shard %sharding_annotated to %sharding_0 : tensor<1200x1200xi64>
More information about the Mlir-commits
mailing list