[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)
Frank Schlimbach
llvmlistbot at llvm.org
Wed Aug 21 03:25:23 PDT 2024
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/104566
>From 6b6751dd16754763dcb0384dc59fcab5d6f4e367 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Aug 2024 19:29:23 +0200
Subject: [PATCH 1/4] initial hack lowering mesh.update_halo to MPI
---
.../mlir/Conversion/MeshToMPI/MeshToMPI.h | 27 +++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 17 ++
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 33 ++++
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/MeshToMPI/CMakeLists.txt | 22 +++
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 171 ++++++++++++++++++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 19 ++
.../MeshToMPI/convert-mesh-to-mpi.mlir | 34 ++++
9 files changed, 325 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
create mode 100644 mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
create mode 100644 mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..6a2c196da45577
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Lowers Mesh communication operations (UpdateHalo, AllGather, ...)
+/// to MPI primitives.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c39..ad8e98442ab8bc 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7bde9e490e4f4e..f9a6f52a22c6ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -869,6 +869,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+ let summary = "Convert Mesh dialect to MPI dialect.";
+ let description = [{
+ This pass converts communication operations
+ from the Mesh dialect to operations from the MPI dialect.
+ }];
+ let dependentDialects = [
+ "memref::MemRefDialect",
+ "mpi::MPIDialect",
+ "scf::SCFDialect"
+ ];
+}
+
//===----------------------------------------------------------------------===//
// NVVMToLLVM
//===----------------------------------------------------------------------===//
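For illustration, this is the kind of IR the pass consumes (a minimal sketch assembled from the tests added in this PR; the mesh shape and halo sizes are example values):

```mlir
mesh.mesh @mesh0(shape = 2x2x4)

func.func @example(%arg0 : memref<12x12xi8>) {
  // exchange the two halos (sizes 2 and 3) of the first dim, which is
  // split over mesh axis 0
  mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
      halo_sizes = [2, 3] : memref<12x12xi8>
  return
}
```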
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..9d1684b78f34f2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -155,6 +155,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
];
}
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+ Pure,
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+      "For given split axes get the linear indices of the direct neighbor processes.";
+  let description = [{
+    Example:
+    ```
+    %idx_down, %idx_up = mesh.neighbors_linear_indices on @mesh[$device]
+        split_axes = $split_axes : index, index
+    ```
+    Given `@mesh` with shape `(10, 20, 30)`,
+    `device` = `(1, 2, 3)`, and
+    `$split_axes` = `[1]`,
+    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
+    and `(1, 3, 3)`: `693`.
+
+ A negative value is returned if `$device` has no neighbor in the given
+ direction along the given `split_axes`.
+ }];
+ let arguments = (ins FlatSymbolRefAttr:$mesh,
+ Variadic<Index>:$device,
+ Mesh_MeshAxesAttr:$split_axes);
+ let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+ let assemblyFormat = [{
+ `on` $mesh `[` $device `]`
+ `split_axes` `=` $split_axes
+ attr-dict `:` type(results)
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Sharding operations.
//===----------------------------------------------------------------------===//
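As a sanity check of the example values above, assuming the usual row-major linearization of the mesh (which the values `633` and `693` imply):

```
linear(x, y, z) = x*(20*30) + y*30 + z    // mesh shape (10, 20, 30)
linear(1, 1, 3) = 600 + 30 + 3 = 633      // down neighbor along axis 1
linear(1, 3, 3) = 600 + 90 + 3 = 693      // up neighbor along axis 1
```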
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e1..3ee237f4e62acd 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+ MeshToMPI.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRIR
+ MLIRLinalgTransforms
+ MLIRMemRefDialect
+ MLIRPass
+ MLIRMeshDialect
+ MLIRMPIDialect
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..b4cf9da8497a2d
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,171 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communicatin ops tp MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+struct ConvertMeshToMPIPass
+ : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+ using Base::Base;
+
+ /// Run the dialect converter on the module.
+ void runOnOperation() override {
+ getOperation()->walk([&](UpdateHaloOp op) {
+ SymbolTableCollection symbolTableCollection;
+ OpBuilder builder(op);
+ auto loc = op.getLoc();
+
+ auto toValue = [&builder, &loc](OpFoldResult &v) {
+ return v.is<Value>()
+ ? v.get<Value>()
+ : builder.create<::mlir::arith::ConstantOp>(
+ loc,
+ builder.getIndexAttr(
+ cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ };
+
+ auto array = op.getInput();
+ auto rank = array.getType().getRank();
+ auto mesh = op.getMesh();
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+ op.getDynamicHaloSizes(), builder);
+ for (auto &sz : haloSizes) {
+ if (sz.is<Value>()) {
+ sz = builder
+ .create<arith::IndexCastOp>(loc, builder.getIndexType(),
+ sz.get<Value>())
+ .getResult();
+ }
+ }
+
+ SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
+ SmallVector<OpFoldResult> shape(rank);
+ for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+ if (ShapedType::isDynamic(s)) {
+ shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
+ } else {
+ shape[i] = builder.getIndexAttr(s);
+ }
+ }
+
+ auto tagAttr = builder.getI32IntegerAttr(91); // whatever
+ auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+ auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
+ auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+ SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ builder.getIndexType());
+ auto myMultiIndex =
+ builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+ .getResult();
+ auto currHaloDim = 0;
+
+ for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+ if (!splitAxes.empty()) {
+ auto tmp = builder
+ .create<NeighborsLinearIndicesOp>(
+ loc, mesh, myMultiIndex, splitAxes)
+ .getResults();
+ Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
+ loc, builder.getI32Type(), tmp[0]),
+ builder.create<arith::IndexCastOp>(
+ loc, builder.getI32Type(), tmp[1])};
+ auto orgDimSize = shape[dim];
+ auto upperOffset = builder.create<arith::SubIOp>(
+ loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
+
+ // make sure we send/recv in a way that does not lead to a dead-lock
+ // This is by far not optimal, this should be at least MPI_sendrecv
+ // and - probably even more importantly - buffers should be re-used
+ // Currently using temporary, contiguous buffer for MPI communication
+ auto genSendRecv = [&](auto dim, bool upperHalo) {
+ auto orgOffset = offsets[dim];
+ shape[dim] =
+ upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
+ auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+ auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto hasFrom = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, from, zero);
+ auto hasTo = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, to, zero);
+ auto buffer = builder.create<memref::AllocOp>(
+ loc, shape, array.getType().getElementType());
+ builder.create<scf::IfOp>(
+ loc, hasTo, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo
+ ? OpFoldResult(builder.getIndexAttr(0))
+ : OpFoldResult(upperOffset);
+ auto subview = builder.create<memref::SubViewOp>(
+ loc, array, offsets, shape, strides);
+ builder.create<memref::CopyOp>(loc, subview, buffer);
+ builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
+ to);
+ builder.create<scf::YieldOp>(loc);
+ });
+ builder.create<scf::IfOp>(
+ loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo
+ ? OpFoldResult(upperOffset)
+ : OpFoldResult(builder.getIndexAttr(0));
+ builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
+ from);
+ auto subview = builder.create<memref::SubViewOp>(
+ loc, array, offsets, shape, strides);
+ builder.create<memref::CopyOp>(loc, buffer, subview);
+ builder.create<scf::YieldOp>(loc);
+ });
+ builder.create<memref::DeallocOp>(loc, buffer);
+ offsets[dim] = orgOffset;
+ };
+
+ genSendRecv(dim, false);
+ genSendRecv(dim, true);
+
+ shape[dim] = builder
+ .create<arith::SubIOp>(
+ loc, toValue(orgDimSize),
+ builder
+ .create<arith::AddIOp>(
+ loc, toValue(haloSizes[dim * 2]),
+ toValue(haloSizes[dim * 2 + 1]))
+ .getResult())
+ .getResult();
+ offsets[dim] = haloSizes[dim * 2];
+ ++currHaloDim;
+ }
+ }
+ });
+ }
+};
+} // namespace
\ No newline at end of file
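To make the shape of the generated code concrete, here is a hand-written sketch of what the walk emits for the send side of one split dimension (op forms follow the FileCheck expectations added in the next commit; the SSA names are invented, and the receive side is analogous):

```mlir
%c0_i32 = arith.constant 0 : i32
%c91_i32 = arith.constant 91 : i32   // the hard-coded tag
%c9 = arith.constant 9 : index       // offset of the upper halo
%pidx:3 = mesh.process_multi_index on @mesh0 : index, index, index
%down, %up = mesh.neighbors_linear_indices on @mesh0[%pidx#0, %pidx#1, %pidx#2] split_axes = [0] : index, index
%to = arith.index_cast %down : index to i32
%has_to = arith.cmpi sge, %to, %c0_i32 : i32
%buf = memref.alloc() : memref<2x12xi8>
scf.if %has_to {
  // copy the boundary rows into the contiguous buffer, then send
  %sub = memref.subview %arg0[%c9, 0] [2, 12] [1, 1]
      : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
  memref.copy %sub, %buf : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
  mpi.send(%buf, %c91_i32, %to) : memref<2x12xi8>, i32, i32
}
memref.dealloc %buf : memref<2x12xi8>
```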
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c35020b4c20ccc..f25bbbf8e274b6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -730,6 +730,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
setNameFn(getResult(), "proc_linear_idx");
}
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return success();
+}
+
+void NeighborsLinearIndicesOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getNeighborDown(), "down_linear_idx");
+ setNameFn(getNeighborUp(), "up_linear_idx");
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..9ef826ca0cdace
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
+
+// -----
+
+// CHECK-LABEL: func @update_halo
+func.func @update_halo_1d(
+ // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+ %arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
+ // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-SAME: split_axes = {{\[\[}}0]]
+ // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
+ %c2 = arith.constant 2 : i64
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
+ halo_sizes = [2, %c2] : memref<12x12xi8>
+ return
+}
+
+func.func @update_halo_2d(
+ // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+ %arg0 : memref<12x12xi8>) {
+ %c2 = arith.constant 2 : i64
+ // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
+ // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
+ // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
+ // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
+ halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
+ : memref<12x12xi8>
+ return
+}
>From c370db1ab727853f48bf121966dc76ffceed8ea7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 16 Aug 2024 10:55:28 +0200
Subject: [PATCH 2/4] dim fixes, proper testing
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 306 ++++++++++--------
.../MeshToMPI/convert-mesh-to-mpi.mlir | 179 ++++++++--
2 files changed, 339 insertions(+), 146 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b4cf9da8497a2d..42d885a109ee79 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This file implements a translation of Mesh communicatin ops tp MPI ops.
+// This file implements a translation of Mesh communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//
@@ -21,6 +21,8 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -34,138 +36,190 @@ using namespace mlir;
using namespace mlir::mesh;
namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+ : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ // Halos are exchanged as 2 blocks per dimension (one for each side: down
+ // and up). It is assumed that the last dim in a default memref is
+ // contiguous, hence iteration starts with the complete halo on the first
+ // dim which should be contiguous (unless the source is not). The size of
+ // the exchanged data will decrease when iterating over dimensions. That's
+    // the exchanged data will decrease when iterating over dimensions. That's
+    // good because the halos of the last dim will be the most fragmented.
+ // local data. subviews and halos have dynamic and static values, so
+ // OpFoldResults are used whenever possible.
+
+ SymbolTableCollection symbolTableCollection;
+ auto loc = op.getLoc();
+
+    // convert an OpFoldResult into a Value
+ auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+ return v.is<Value>()
+ ? v.get<Value>()
+ : rewriter.create<::mlir::arith::ConstantOp>(
+ loc,
+ rewriter.getIndexAttr(
+ cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ };
+
+ auto array = op.getInput();
+ auto rank = array.getType().getRank();
+ auto mesh = op.getMesh();
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+ op.getDynamicHaloSizes(), rewriter);
+ // subviews need Index values
+ for (auto &sz : haloSizes) {
+ if (sz.is<Value>()) {
+ sz = rewriter
+ .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+ sz.get<Value>())
+ .getResult();
+ }
+ }
+
+ // most of the offset/size/stride data is the same for all dims
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> shape(rank);
+ // we need the actual shape to compute offsets and sizes
+ for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+ if (ShapedType::isDynamic(s)) {
+ shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+ } else {
+ shape[i] = rewriter.getIndexAttr(s);
+ }
+ }
+
+ auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+ auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+ auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+ auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+ SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ rewriter.getIndexType());
+ auto myMultiIndex =
+ rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+ .getResult();
+ // halo sizes are provided for split dimensions only
+ auto currHaloDim = 0;
+
+ for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+ if (splitAxes.empty()) {
+ continue;
+ }
+ // Get the linearized ids of the neighbors (down and up) for the
+ // given split
+ auto tmp = rewriter
+ .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
+ splitAxes)
+ .getResults();
+ // MPI operates on i32...
+ Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), tmp[0]),
+ rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), tmp[1])};
+ // store for later
+ auto orgDimSize = shape[dim];
+ // this dim's offset to the start of the upper halo
+ auto upperOffset = rewriter.create<arith::SubIOp>(
+ loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+
+ // Make sure we send/recv in a way that does not lead to a dead-lock.
+      // The current approach is far from optimal; it should at least use a
+      // red-black pattern or MPI_sendrecv.
+ // Also, buffers should be re-used.
+ // Still using temporary contiguous buffers for MPI communication...
+ // Still yielding a "serialized" communication pattern...
+ auto genSendRecv = [&](auto dim, bool upperHalo) {
+ auto orgOffset = offsets[dim];
+ shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+ : haloSizes[currHaloDim * 2];
+ // Check if we need to send and/or receive
+ // Processes on the mesh borders have only one neighbor
+ auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+ auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto hasFrom = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, from, zero);
+ auto hasTo = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, to, zero);
+ auto buffer = rewriter.create<memref::AllocOp>(
+ loc, shape, array.getType().getElementType());
+ // if has neighbor: copy halo data from array to buffer and send
+ rewriter.create<scf::IfOp>(
+ loc, hasTo, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
+ : OpFoldResult(upperOffset);
+ auto subview = builder.create<memref::SubViewOp>(
+ loc, array, offsets, shape, strides);
+ builder.create<memref::CopyOp>(loc, subview, buffer);
+ builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+ builder.create<scf::YieldOp>(loc);
+ });
+ // if has neighbor: receive halo data into buffer and copy to array
+ rewriter.create<scf::IfOp>(
+ loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+ offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
+ : OpFoldResult(builder.getIndexAttr(0));
+ builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+ auto subview = builder.create<memref::SubViewOp>(
+ loc, array, offsets, shape, strides);
+ builder.create<memref::CopyOp>(loc, buffer, subview);
+ builder.create<scf::YieldOp>(loc);
+ });
+ rewriter.create<memref::DeallocOp>(loc, buffer);
+ offsets[dim] = orgOffset;
+ };
+
+ genSendRecv(dim, false);
+ genSendRecv(dim, true);
+
+ // prepare shape and offsets for next split dim
+ auto _haloSz =
+ rewriter
+ .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]))
+ .getResult();
+ // the shape for next halo excludes the halo on both ends for the
+ // current dim
+ shape[dim] =
+ rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
+ .getResult();
+ // the offsets for next halo starts after the down halo for the
+ // current dim
+ offsets[dim] = haloSizes[currHaloDim * 2];
+ // on to next halo
+ ++currHaloDim;
+ }
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+};
+
struct ConvertMeshToMPIPass
: public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
using Base::Base;
/// Run the dialect converter on the module.
void runOnOperation() override {
- getOperation()->walk([&](UpdateHaloOp op) {
- SymbolTableCollection symbolTableCollection;
- OpBuilder builder(op);
- auto loc = op.getLoc();
-
- auto toValue = [&builder, &loc](OpFoldResult &v) {
- return v.is<Value>()
- ? v.get<Value>()
- : builder.create<::mlir::arith::ConstantOp>(
- loc,
- builder.getIndexAttr(
- cast<IntegerAttr>(v.get<Attribute>()).getInt()));
- };
+ auto *ctx = &getContext();
+ mlir::RewritePatternSet patterns(ctx);
- auto array = op.getInput();
- auto rank = array.getType().getRank();
- auto mesh = op.getMesh();
- auto meshOp = getMesh(op, symbolTableCollection);
- auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
- op.getDynamicHaloSizes(), builder);
- for (auto &sz : haloSizes) {
- if (sz.is<Value>()) {
- sz = builder
- .create<arith::IndexCastOp>(loc, builder.getIndexType(),
- sz.get<Value>())
- .getResult();
- }
- }
-
- SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
- SmallVector<OpFoldResult> shape(rank);
- for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
- if (ShapedType::isDynamic(s)) {
- shape[i] = builder.create<memref::DimOp>(loc, array, s).getResult();
- } else {
- shape[i] = builder.getIndexAttr(s);
- }
- }
+ patterns.insert<ConvertUpdateHaloOp>(ctx);
- auto tagAttr = builder.getI32IntegerAttr(91); // whatever
- auto tag = builder.create<::mlir::arith::ConstantOp>(loc, tagAttr);
- auto zeroAttr = builder.getI32IntegerAttr(0); // whatever
- auto zero = builder.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
- SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
- builder.getIndexType());
- auto myMultiIndex =
- builder.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
- .getResult();
- auto currHaloDim = 0;
-
- for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
- if (!splitAxes.empty()) {
- auto tmp = builder
- .create<NeighborsLinearIndicesOp>(
- loc, mesh, myMultiIndex, splitAxes)
- .getResults();
- Value neighbourIDs[2] = {builder.create<arith::IndexCastOp>(
- loc, builder.getI32Type(), tmp[0]),
- builder.create<arith::IndexCastOp>(
- loc, builder.getI32Type(), tmp[1])};
- auto orgDimSize = shape[dim];
- auto upperOffset = builder.create<arith::SubIOp>(
- loc, toValue(shape[dim]), toValue(haloSizes[dim * 2 + 1]));
-
- // make sure we send/recv in a way that does not lead to a dead-lock
- // This is by far not optimal, this should be at least MPI_sendrecv
- // and - probably even more importantly - buffers should be re-used
- // Currently using temporary, contiguous buffer for MPI communication
- auto genSendRecv = [&](auto dim, bool upperHalo) {
- auto orgOffset = offsets[dim];
- shape[dim] =
- upperHalo ? haloSizes[dim * 2 + 1] : haloSizes[dim * 2];
- auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
- auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
- auto hasFrom = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, from, zero);
- auto hasTo = builder.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, to, zero);
- auto buffer = builder.create<memref::AllocOp>(
- loc, shape, array.getType().getElementType());
- builder.create<scf::IfOp>(
- loc, hasTo, [&](OpBuilder &builder, Location loc) {
- offsets[dim] = upperHalo
- ? OpFoldResult(builder.getIndexAttr(0))
- : OpFoldResult(upperOffset);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, shape, strides);
- builder.create<memref::CopyOp>(loc, subview, buffer);
- builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag,
- to);
- builder.create<scf::YieldOp>(loc);
- });
- builder.create<scf::IfOp>(
- loc, hasFrom, [&](OpBuilder &builder, Location loc) {
- offsets[dim] = upperHalo
- ? OpFoldResult(upperOffset)
- : OpFoldResult(builder.getIndexAttr(0));
- builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag,
- from);
- auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, shape, strides);
- builder.create<memref::CopyOp>(loc, buffer, subview);
- builder.create<scf::YieldOp>(loc);
- });
- builder.create<memref::DeallocOp>(loc, buffer);
- offsets[dim] = orgOffset;
- };
-
- genSendRecv(dim, false);
- genSendRecv(dim, true);
-
- shape[dim] = builder
- .create<arith::SubIOp>(
- loc, toValue(orgDimSize),
- builder
- .create<arith::AddIOp>(
- loc, toValue(haloSizes[dim * 2]),
- toValue(haloSizes[dim * 2 + 1]))
- .getResult())
- .getResult();
- offsets[dim] = haloSizes[dim * 2];
- ++currHaloDim;
- }
- }
- });
+ (void)mlir::applyPatternsAndFoldGreedily(getOperation(),
+ std::move(patterns));
}
};
-} // namespace
\ No newline at end of file
+
+} // namespace
+
+// Create a pass that converts Mesh to MPI.
+std::unique_ptr<::mlir::Pass> mlir::createConvertMeshToMPIPass() {
+  return std::make_unique<ConvertMeshToMPIPass>();
+}
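Note that the rewrite keeps supporting mixed static/dynamic halo sizes, which is why the pattern index-casts `Value` operands above; the accepted form, as in the first commit's test, is a sketch like:

```mlir
%c2 = arith.constant 2 : i64
mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
    halo_sizes = [2, %c2] : memref<12x12xi8>
```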
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 9ef826ca0cdace..5f563364272d96 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -1,34 +1,173 @@
-// RUN: mlir-opt %s -split-input-file -convert-mesh-to-mpi | FileCheck %s
+// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
// CHECK: mesh.mesh @mesh0
mesh.mesh @mesh0(shape = 2x2x4)
-// -----
-
-// CHECK-LABEL: func @update_halo
-func.func @update_halo_1d(
- // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+ // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
- // CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i64
- // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
- // CHECK-SAME: split_axes = {{\[\[}}0]]
- // CHECK-SAME: halo_sizes = [2, %c2_i64] : memref<12x12xi8>
- %c2 = arith.constant 2 : i64
+ // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+ // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+ // CHECK-NEXT: scf.if [[v3]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v2]] {
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+ // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+ // CHECK-NEXT: scf.if [[v5]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v4]] {
+ // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+ // CHECK-NEXT: return
mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
- halo_sizes = [2, %c2] : memref<12x12xi8>
+ halo_sizes = [2, 3] : memref<12x12xi8>
+ return
+}
+
+// CHECK-LABEL: func @update_halo_1d_second
+func.func @update_halo_1d_second(
+ // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
+ %arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
+ // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+ // CHECK-NEXT: scf.if [[v3]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v2]] {
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+ // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+ // CHECK-NEXT: scf.if [[v5]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
+ // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v4]] {
+ // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
+ // CHECK-NEXT: return
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
+ halo_sizes = [2, 3] : memref<12x12xi8>
return
}
+// CHECK-LABEL: func @update_halo_2d
func.func @update_halo_2d(
- // CHECK-SAME: %[[ARG:.*]]: memref<12x12xi8>
+ // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
- %c2 = arith.constant 2 : i64
- // CHECK-NEXT: mesh.update_halo %[[ARG]] on @mesh0
- // CHECK-SAME: split_axes = {{\[\[}}0], [1]]
- // CHECK-SAME: halo_sizes = [2, 2, %[[C2]], 2]
- // CHECK-SAME: target_halo_sizes = [3, 3, 2, 2] : memref<12x12xi8>
+ // CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+ // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+ // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+ // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+ // CHECK-NEXT: scf.if [[v3]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v2]] {
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+ // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+ // CHECK-NEXT: scf.if [[v5]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v4]] {
+ // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
+ // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+ // CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
+ // CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
+ // CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+ // CHECK-NEXT: scf.if [[v9]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+ // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v8]] {
+ // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+ // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+ // CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+ // CHECK-NEXT: scf.if [[v11]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v10]] {
+ // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
+ // CHECK-NEXT: return
mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
- halo_sizes = [2, 2, %c2, 2] target_halo_sizes = [3, 3, 2, 2]
- : memref<12x12xi8>
+ halo_sizes = [1, 2, 3, 4]
+ : memref<12x12xi8>
return
}
>From dee4faf157ec1966e16e9ef858cf123e780cfaaa Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 20 Aug 2024 19:23:13 +0200
Subject: [PATCH 3/4] fixed corner halos by reversing data-exchanges from high
to low dims
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 91 +++++-----
.../MeshToMPI/convert-mesh-to-mpi.mlir | 161 +++++++++---------
2 files changed, 137 insertions(+), 115 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 42d885a109ee79..9cf9458ce2b687 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -70,6 +70,7 @@ struct ConvertUpdateHaloOp
auto array = op.getInput();
auto rank = array.getType().getRank();
+ auto opSplitAxes = op.getSplitAxes().getAxes();
auto mesh = op.getMesh();
auto meshOp = getMesh(op, symbolTableCollection);
auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
@@ -87,32 +88,54 @@ struct ConvertUpdateHaloOp
// most of the offset/size/stride data is the same for all dims
SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> shape(rank);
+ SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
+ auto currHaloDim = -1; // halo sizes are provided for split dimensions only
// we need the actual shape to compute offsets and sizes
- for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+ for (auto i = 0; i < rank; ++i) {
+ auto s = array.getType().getShape()[i];
if (ShapedType::isDynamic(s)) {
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
} else {
shape[i] = rewriter.getIndexAttr(s);
}
+
+ if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+ ++currHaloDim;
+        // the offsets for lower dims start after their down halo
+ offsets[i] = haloSizes[currHaloDim * 2];
+
+ // prepare shape and offsets of highest dim's halo exchange
+ auto _haloSz =
+ rewriter
+ .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]))
+ .getResult();
+        // the halo shape of lower dims excludes the halos
+ dimSizes[i] =
+ rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+ .getResult();
+ } else {
+ dimSizes[i] = shape[i];
+ }
}
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
rewriter.getIndexType());
auto myMultiIndex =
rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
.getResult();
- // halo sizes are provided for split dimensions only
- auto currHaloDim = 0;
-
- for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+ // traverse all split axes from high to low dim
+ for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
+ auto splitAxes = opSplitAxes[dim];
if (splitAxes.empty()) {
continue;
}
+ assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
auto tmp = rewriter
@@ -124,11 +147,13 @@ struct ConvertUpdateHaloOp
loc, rewriter.getI32Type(), tmp[0]),
rewriter.create<arith::IndexCastOp>(
loc, rewriter.getI32Type(), tmp[1])};
- // store for later
- auto orgDimSize = shape[dim];
- // this dim's offset to the start of the upper halo
- auto upperOffset = rewriter.create<arith::SubIOp>(
+
+ auto lowerRecvOffset = rewriter.getIndexAttr(0);
+ auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
+ auto upperRecvOffset = rewriter.create<arith::SubIOp>(
loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+ auto upperSendOffset = rewriter.create<arith::SubIOp>(
+ loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));
// Make sure we send/recv in a way that does not lead to a dead-lock.
// The current approach is by far not optimal, this should be at least
@@ -136,10 +161,10 @@ struct ConvertUpdateHaloOp
// Also, buffers should be re-used.
// Still using temporary contiguous buffers for MPI communication...
// Still yielding a "serialized" communication pattern...
- auto genSendRecv = [&](auto dim, bool upperHalo) {
+ auto genSendRecv = [&](bool upperHalo) {
auto orgOffset = offsets[dim];
- shape[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
- : haloSizes[currHaloDim * 2];
+ dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+ : haloSizes[currHaloDim * 2];
// Check if we need to send and/or receive
// Processes on the mesh borders have only one neighbor
auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
@@ -149,14 +174,14 @@ struct ConvertUpdateHaloOp
auto hasTo = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, to, zero);
auto buffer = rewriter.create<memref::AllocOp>(
- loc, shape, array.getType().getElementType());
+ loc, dimSizes, array.getType().getElementType());
// if has neighbor: copy halo data from array to buffer and send
rewriter.create<scf::IfOp>(
loc, hasTo, [&](OpBuilder &builder, Location loc) {
- offsets[dim] = upperHalo ? OpFoldResult(builder.getIndexAttr(0))
- : OpFoldResult(upperOffset);
+ offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
+ : OpFoldResult(upperSendOffset);
auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, shape, strides);
+ loc, array, offsets, dimSizes, strides);
builder.create<memref::CopyOp>(loc, subview, buffer);
builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
builder.create<scf::YieldOp>(loc);
@@ -164,11 +189,11 @@ struct ConvertUpdateHaloOp
// if has neighbor: receive halo data into buffer and copy to array
rewriter.create<scf::IfOp>(
loc, hasFrom, [&](OpBuilder &builder, Location loc) {
- offsets[dim] = upperHalo ? OpFoldResult(upperOffset)
- : OpFoldResult(builder.getIndexAttr(0));
+ offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
+ : OpFoldResult(lowerRecvOffset);
builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
auto subview = builder.create<memref::SubViewOp>(
- loc, array, offsets, shape, strides);
+ loc, array, offsets, dimSizes, strides);
builder.create<memref::CopyOp>(loc, buffer, subview);
builder.create<scf::YieldOp>(loc);
});
@@ -176,25 +201,15 @@ struct ConvertUpdateHaloOp
offsets[dim] = orgOffset;
};
- genSendRecv(dim, false);
- genSendRecv(dim, true);
-
- // prepare shape and offsets for next split dim
- auto _haloSz =
- rewriter
- .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
- toValue(haloSizes[currHaloDim * 2 + 1]))
- .getResult();
- // the shape for next halo excludes the halo on both ends for the
- // current dim
- shape[dim] =
- rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), _haloSz)
- .getResult();
- // the offsets for next halo starts after the down halo for the
- // current dim
- offsets[dim] = haloSizes[currHaloDim * 2];
+ genSendRecv(false);
+ genSendRecv(true);
+
+      // the shape for lower dims includes higher dims' halos
+ dimSizes[dim] = shape[dim];
+ // -> the offset for higher dims is always 0
+ offsets[dim] = rewriter.getIndexAttr(0);
// on to next halo
- ++currHaloDim;
+ --currHaloDim;
}
rewriter.eraseOp(op);
return mlir::success();
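One way to see why reversing the traversal fixes the corner halos, traced with the `update_halo_2d` test values below (buffer shapes inferred from the FileCheck expectations; a sketch, not normative):

```
memref<12x12xi8>, split_axes = [[0], [1]], halo_sizes = [1, 2, 3, 4]

1. dim 1 (the contiguous dim) exchanges first: buffers 9x3 and 9x4,
   since the rows exclude dim 0's halos (12 - 1 - 2 = 9).
2. dim 0 exchanges second: buffers 1x12 and 2x12, i.e. full rows
   including the halo columns just received in step 1.

Step 2 forwards the freshly updated dim-1 halo data to the dim-0
neighbors, so the corner regions also receive valid values (relayed
from the diagonal neighbors in two hops).
```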
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 5f563364272d96..c3b0dc12e6d746 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -6,8 +6,10 @@ mesh.mesh @mesh0(shape = 2x2x4)
// CHECK-LABEL: func @update_halo_1d_first
func.func @update_halo_1d_first(
// CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
- %arg0 : memref<12x12xi8>) {
+ %arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
// CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+ // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
// CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
@@ -18,7 +20,7 @@ func.func @update_halo_1d_first(
// CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
// CHECK-NEXT: scf.if [[v3]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc7]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
// CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
// CHECK-NEXT: }
@@ -32,8 +34,8 @@ func.func @update_halo_1d_first(
// CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
// CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
// CHECK-NEXT: scf.if [[v5]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc2]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1], offset: ?>> to memref<3x12xi8>
// CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: scf.if [[v4]] {
@@ -42,9 +44,9 @@ func.func @update_halo_1d_first(
// CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
// CHECK-NEXT: }
// CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
- // CHECK-NEXT: return
mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
halo_sizes = [2, 3] : memref<12x12xi8>
+ // CHECK-NEXT: return
return
}
@@ -52,44 +54,46 @@ func.func @update_halo_1d_first(
func.func @update_halo_1d_second(
// CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
- // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
- // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
- // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
- // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
- // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
- // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
- // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
- // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
- // CHECK-NEXT: scf.if [[v3]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
- // CHECK-NEXT: }
- // CHECK-NEXT: scf.if [[v2]] {
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
- // CHECK-NEXT: }
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
- // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
- // CHECK-NEXT: scf.if [[v5]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1]>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1]>> to memref<12x3xi8>
- // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
- // CHECK-NEXT: }
- // CHECK-NEXT: scf.if [[v4]] {
- // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: }
- // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
- // CHECK-NEXT: return
+  // CHECK-NEXT: [[vc7:%.*]] = arith.constant 7 : index
+  // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+  // CHECK-NEXT: [[vc2:%.*]] = arith.constant 2 : index
+  // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+  // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+  // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+  // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [3] : index, index
+  // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+  // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<12x2xi8>
+  // CHECK-NEXT: scf.if [[v3]] {
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c7] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<12x2xi8, strided<[12, 1], offset: ?>> to memref<12x2xi8>
+  // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v2]] {
+  // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<12x2xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [12, 2] [1, 1] : memref<12x12xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<12x2xi8> to memref<12x2xi8, strided<[12, 1]>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc]] : memref<12x2xi8>
+  // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+  // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<12x3xi8>
+  // CHECK-NEXT: scf.if [[v5]] {
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c2] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<12x3xi8, strided<[12, 1], offset: ?>> to memref<12x3xi8>
+  // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT: }
+  // CHECK-NEXT: scf.if [[v4]] {
+  // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<12x3xi8>, i32, i32
+  // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, %c9] [12, 3] [1, 1] : memref<12x12xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<12x3xi8> to memref<12x3xi8, strided<[12, 1], offset: ?>>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<12x3xi8>
mesh.update_halo %arg0 on @mesh0 split_axes = [[], [3]]
halo_sizes = [2, 3] : memref<12x12xi8>
+ // CHECK-NEXT: return
return
}
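The subview offsets checked in @update_halo_1d_second above follow directly from
`halo_sizes = [2, 3]` on a dimension of local extent 12: the owned elements are
columns [2, 9). A minimal standalone sketch of that arithmetic, assuming uniform
halo widths across neighboring processes (plain C++, not part of the patch):
```
#include <cstdio>

int main() {
  long extent = 12, low = 2, high = 3; // halo_sizes = [2, 3]
  // The block sent "down" fills that neighbor's low halo of width `low`.
  printf("send down: cols [%ld, %ld)\n", extent - high - low, extent - high); // [7, 9)
  // The block received from "up" lands in our own low halo.
  printf("recv up:   cols [0, %ld)\n", low);                                  // [0, 2)
  // The block sent "up" fills that neighbor's high halo of width `high`.
  printf("send up:   cols [%ld, %ld)\n", low, low + high);                    // [2, 5)
  // The block received from "down" lands in our own high halo.
  printf("recv down: cols [%ld, %ld)\n", extent - high, extent);              // [9, 12)
  return 0;
}
```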
@@ -97,77 +101,80 @@ func.func @update_halo_1d_second(
func.func @update_halo_2d(
// CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
%arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
+ // CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
+ // CHECK-NEXT: [[vc5:%.*]] = arith.constant 5 : index
// CHECK-NEXT: [[vc8:%.*]] = arith.constant 8 : index
+ // CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
// CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
- // CHECK-NEXT: [[vc10:%.*]] = arith.constant 10 : index
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
// CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
- // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+ // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
// CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
// CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
// CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
// CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<1x12xi8>
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
// CHECK-NEXT: scf.if [[v3]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<1x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c5] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<?x3xi8>, i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: scf.if [[v2]] {
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<1x12xi8>, i32, i32
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<?x3xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
// CHECK-NEXT: }
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<1x12xi8>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<?x3xi8>
// CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
// CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<2x12xi8>
+ // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
// CHECK-NEXT: scf.if [[v5]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<2x12xi8, strided<[12, 1]>> to memref<2x12xi8>
- // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c3] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<?x4xi8, strided<[12, 1], offset: ?>> to memref<?x4xi8>
+ // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<?x4xi8>, i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: scf.if [[v4]] {
- // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<?x4xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
// CHECK-NEXT: }
- // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<2x12xi8>
- // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [1] : index, index
+ // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<?x4xi8>
+ // CHECK-NEXT: [[vdown_linear_idx_1:%.*]], [[vup_linear_idx_2:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
// CHECK-NEXT: [[v6:%.*]] = arith.index_cast [[vdown_linear_idx_1]] : index to i32
// CHECK-NEXT: [[v7:%.*]] = arith.index_cast [[vup_linear_idx_2]] : index to i32
// CHECK-NEXT: [[v8:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
// CHECK-NEXT: [[v9:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc([[vc9]]) : memref<?x3xi8>
+ // CHECK-NEXT: [[valloc_3:%.*]] = memref.alloc() : memref<1x12xi8>
// CHECK-NEXT: scf.if [[v9]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref<?x3xi8, strided<[12, 1], offset: ?>> to memref<?x3xi8>
- // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<?x3xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_3]] : memref<1x12xi8, strided<[12, 1], offset: ?>> to memref<1x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc_3]], [[vc91_i32]], [[v6]]) : memref<1x12xi8>, i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: scf.if [[v8]] {
- // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<?x3xi8>, i32, i32
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 3] [1, 1] : memref<12x12xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
- // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref<?x3xi8> to memref<?x3xi8, strided<[12, 1], offset: 12>>
+ // CHECK-NEXT: mpi.recv([[valloc_3]], [[vc91_i32]], [[v7]]) : memref<1x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [1, 12] [1, 1] : memref<12x12xi8> to memref<1x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_3]], [[vsubview]] : memref<1x12xi8> to memref<1x12xi8, strided<[12, 1]>>
// CHECK-NEXT: }
- // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<?x3xi8>
+ // CHECK-NEXT: memref.dealloc [[valloc_3]] : memref<1x12xi8>
// CHECK-NEXT: [[v10:%.*]] = arith.cmpi sge, [[v6]], [[vc0_i32]] : i32
// CHECK-NEXT: [[v11:%.*]] = arith.cmpi sge, [[v7]], [[vc0_i32]] : i32
- // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc([[vc9]]) : memref<?x4xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<2x12xi8>
// CHECK-NEXT: scf.if [[v11]] {
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 0] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: 12>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref<?x4xi8, strided<[12, 1], offset: 12>> to memref<?x4xi8>
- // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<?x4xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc1]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_4]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[v7]]) : memref<2x12xi8>, i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: scf.if [[v10]] {
- // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<?x4xi8>, i32, i32
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, %c8] [[[vc9]], 4] [1, 1] : memref<12x12xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
- // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref<?x4xi8> to memref<?x4xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: mpi.recv([[valloc_4]], [[vc91_i32]], [[v6]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc10]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_4]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
// CHECK-NEXT: }
- // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<?x4xi8>
- // CHECK-NEXT: return
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<2x12xi8>
mesh.update_halo %arg0 on @mesh0 split_axes = [[0], [1]]
halo_sizes = [1, 2, 3, 4]
: memref<12x12xi8>
+ // CHECK-NEXT: return
return
}
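The interleaved `halo_sizes = [1, 2, 3, 4]` are consumed pairwise per split-axes
group: dimension 0 gets (low, high) = (1, 2) and dimension 1 gets (3, 4). A
minimal sketch of how the buffer shapes checked above arise (plain C++, not part
of the patch):
```
#include <cstdio>

int main() {
  long shape[2] = {12, 12};         // local shape, halos included
  long haloSizes[4] = {1, 2, 3, 4}; // interleaved (low, high) per dimension
  long owned0 = shape[0] - haloSizes[0] - haloSizes[1]; // 9 owned rows
  // Dimension 1 is exchanged first; its buffers are clipped to the owned
  // rows of the lower dimension 0.
  printf("dim 1: %ldx%ld and %ldx%ld\n", owned0, haloSizes[2], owned0, haloSizes[3]); // 9x3, 9x4
  // Dimension 0 is exchanged second; its buffers keep the full column
  // extent, halos of the higher dimension 1 included.
  printf("dim 0: %ldx%ld and %ldx%ld\n", haloSizes[0], shape[1], haloSizes[1], shape[1]); // 1x12, 2x12
  return 0;
}
```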
>From 443b9c0e5cda6a3b216be97d685bf047ffd914f7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 21 Aug 2024 12:25:08 +0200
Subject: [PATCH 4/4] addressed review comments (docs, formatting)
---
.../mlir/Conversion/MeshToMPI/MeshToMPI.h | 2 +-
mlir/include/mlir/Conversion/Passes.td | 2 +-
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 6 +++---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 16 +++++++++-------
4 files changed, 14 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
index 6a2c196da45577..b8803f386f7356 100644
--- a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -1,4 +1,4 @@
-//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f9a6f52a22c6ed..381b472895d24b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -877,7 +877,7 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
let summary = "Convert Mesh dialect to MPI dialect.";
let description = [{
This pass converts communication operations
- from the Mesh dialect to operations from the MPI dialect.
+ from the Mesh dialect to the MPI dialect.
}];
let dependentDialects = [
"memref::MemRefDialect",
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 9d1684b78f34f2..1027fc436dcead 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -161,7 +161,7 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary =
- "For given split axes get the linear index the direct neighbor processes.";
+ "For given split axes get the linear indices of the direct neighbor processes.";
let description = [{
Example:
```
@@ -171,8 +171,8 @@ def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
Given `@mesh` with shape `(10, 20, 30)`,
`device` = `(1, 2, 3)`
`$split_axes` = `[1]`
- it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
- and `(1, 3, 3)`: `693`.
+    it returns two indices, `633` and `693`: the linear indices of the previous
+    process `(1, 1, 3)` and the next process `(1, 3, 3)` along split axis `1`.
A negative value is returned if `$device` has no neighbor in the given
direction along the given `split_axes`.
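The indices in the example are plain row-major (C-order) linearizations of the
multi-index within the mesh shape; a minimal check of the arithmetic
(standalone C++, assuming row-major process numbering):
```
#include <cstdio>

// Row-major linear index of a multi-index within a mesh of the given shape.
long linearIndex(const long idx[], const long shape[], int rank) {
  long lin = 0;
  for (int i = 0; i < rank; ++i)
    lin = lin * shape[i] + idx[i];
  return lin;
}

int main() {
  long shape[3] = {10, 20, 30};
  long prev[3] = {1, 1, 3}, next[3] = {1, 3, 3};
  printf("%ld\n", linearIndex(prev, shape, 3)); // (1*20 + 1)*30 + 3 = 633
  printf("%ld\n", linearIndex(next, shape, 3)); // (1*20 + 3)*30 + 3 = 693
  return 0;
}
```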
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9cf9458ce2b687..ea1323e43462cd 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -45,15 +45,17 @@ struct ConvertUpdateHaloOp
mlir::LogicalResult
matchAndRewrite(mlir::mesh::UpdateHaloOp op,
mlir::PatternRewriter &rewriter) const override {
+ // The input/output memref is assumed to be in C memory order.
// Halos are exchanged as 2 blocks per dimension (one for each side: down
- // and up). It is assumed that the last dim in a default memref is
- // contiguous, hence iteration starts with the complete halo on the first
- // dim which should be contiguous (unless the source is not). The size of
- // the exchanged data will decrease when iterating over dimensions. That's
- // good because the halos of last dim will be most fragmented.
+    // and up). For each haloed dimension `d`, the exchanged blocks are
+    // expressed as multi-dimensional subviews. The subviews include the
+    // potential halos of higher dimensions `dh > d`, exclude the halos of
+    // lower dimensions `dl < d`, and for dimension `d` cover the currently
+    // exchanged halo only. By iterating from higher to lower dimensions this
+    // also updates the halos in the 'corners'.
// memref.subview is used to read and write the halo data from and to the
- // local data. subviews and halos have dynamic and static values, so
- // OpFoldResults are used whenever possible.
+ // local data. Because subviews and halos can have mixed dynamic and static
+ // shapes, OpFoldResults are used whenever possible.
SymbolTableCollection symbolTableCollection;
auto loc = op.getLoc();
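As a standalone illustration of the comment above (not the rewrite pattern
itself), the geometry of the blocks sent to each neighbor could be sketched as
follows, assuming a C-ordered buffer and uniform halo widths; the received
blocks land symmetrically in the process's own low and high halos:
```
#include <cstdio>
#include <vector>

struct Halo { long low, high; };

int main() {
  // Local shape (halos included) and per-dimension halo widths, as in the
  // update_halo_2d test above.
  std::vector<long> shape = {12, 12};
  std::vector<Halo> halos = {{1, 2}, {3, 4}};
  int rank = (int)shape.size();
  // Iterate from the last (contiguous) dimension down to the first so the
  // later, lower-dimension exchanges also fill the corner halos.
  for (int d = rank - 1; d >= 0; --d) {
    std::vector<long> off(rank), sz(rank);
    for (int i = 0; i < rank; ++i) {
      if (i > d) { // higher dimensions: full extent, halos included
        off[i] = 0;
        sz[i] = shape[i];
      } else if (i < d) { // lower dimensions: owned region only
        off[i] = halos[i].low;
        sz[i] = shape[i] - halos[i].low - halos[i].high;
      }
    }
    // Dimension `d` itself: the block sent "down" fills that neighbor's low
    // halo; the block sent "up" fills that neighbor's high halo.
    off[d] = shape[d] - halos[d].high - halos[d].low;
    sz[d] = halos[d].low;
    printf("dim %d down block: off=(%ld, %ld) sz=(%ld, %ld)\n", d, off[0], off[1], sz[0], sz[1]);
    off[d] = halos[d].low;
    sz[d] = halos[d].high;
    printf("dim %d up block:   off=(%ld, %ld) sz=(%ld, %ld)\n", d, off[0], off[1], sz[0], sz[1]);
  }
  return 0; // printing is hardcoded for the 2-D example
}
```
For the test values this prints the send subviews checked above: (1, 5)/(9, 3)
and (1, 3)/(9, 4) for dimension 1, then (9, 0)/(1, 12) and (1, 0)/(2, 12) for
dimension 0.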