[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Sep 3 00:35:09 PDT 2024
================
@@ -0,0 +1,242 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+ : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ // The input/output memref is assumed to be in C memory order.
+ // Halos are exchanged as 2 blocks per dimension (one for each side: down
+ // and up). For each haloed dimension `d`, the exchanged blocks are
+ // expressed as multi-dimensional subviews. The subviews include potential
+ // halos of higher dimensions `dh > d`, no halos for the lower dimensions
+ // `dl < d` and for dimension `d` the currently exchanged halo only.
+ // By iterating from higher to lower dimensions this also updates the halos
+ // in the 'corners'.
+ // memref.subview is used to read and write the halo data from and to the
+ // local data. Because subviews and halos can have mixed dynamic and static
+ // shapes, OpFoldResults are used whenever possible.
+
+ SymbolTableCollection symbolTableCollection;
+ auto loc = op.getLoc();
+
+ // convert an OpFoldResult into a Value
+ auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+ return v.is<Value>()
+ ? v.get<Value>()
+ : rewriter.create<::mlir::arith::ConstantOp>(
+ loc,
+ rewriter.getIndexAttr(
+ cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ };
+
+ auto array = op.getInput();
+ auto rank = array.getType().getRank();
+ auto opSplitAxes = op.getSplitAxes().getAxes();
+ auto mesh = op.getMesh();
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+ op.getDynamicHaloSizes(), rewriter);
+ // subviews need Index values
+ for (auto &sz : haloSizes) {
+ if (sz.is<Value>()) {
+ sz = rewriter
+ .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+ sz.get<Value>())
+ .getResult();
+ }
+ }
+
+ // most of the offset/size/stride data is the same for all dims
+ SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
+ auto currHaloDim = -1; // halo sizes are provided for split dimensions only
+ // we need the actual shape to compute offsets and sizes
+ for (auto i = 0; i < rank; ++i) {
+ auto s = array.getType().getShape()[i];
+ if (ShapedType::isDynamic(s)) {
+ shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+ } else {
+ shape[i] = rewriter.getIndexAttr(s);
+ }
+
+ if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+ ++currHaloDim;
+ // the offsets for lower dims start after their down halo
+ offsets[i] = haloSizes[currHaloDim * 2];
+
+ // prepare shape and offsets of highest dim's halo exchange
+ auto _haloSz =
+ rewriter
+ .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]))
+ .getResult();
+ // the halo shape of lower dims excludes the halos
+ dimSizes[i] =
+ rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+ .getResult();
+ } else {
+ dimSizes[i] = shape[i];
+ }
+ }
+
+ auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
----------------
fschlimb wrote:
Since we have only COMM_WORLD this might reduce the risk of tag conflicts (like in multi-threaded cases).
https://github.com/llvm/llvm-project/pull/104566
More information about the Mlir-commits
mailing list