[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)

Frank Schlimbach llvmlistbot at llvm.org
Tue Sep 3 00:35:09 PDT 2024


================
@@ -0,0 +1,242 @@
+//===- MeshToMPI.cpp - Mesh to MPI  dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // The input/output memref is assumed to be in C memory order.
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). For each haloed dimension `d`, the exchanged blocks are
+    // expressed as multi-dimensional subviews. The subviews include potential
+    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
+    // `dl < d` and for dimension `d` the currently exchanged halo only.
+    // By iterating form higher to lower dimensions this also updates the halos
+    // in the 'corners'.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. Because subviews and halos can have mixed dynamic and static
+    // shapes, OpFoldResults are used whenever possible.
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // convert a OpFoldResult into a Value
+    auto toValue = [&rewriter, &loc](OpFoldResult &v) {
+      return v.is<Value>()
+                 ? v.get<Value>()
+                 : rewriter.create<::mlir::arith::ConstantOp>(
+                       loc,
+                       rewriter.getIndexAttr(
+                           cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+    };
+
+    auto array = op.getInput();
+    auto rank = array.getType().getRank();
+    auto opSplitAxes = op.getSplitAxes().getAxes();
+    auto mesh = op.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                    op.getDynamicHaloSizes(), rewriter);
+    // subviews need Index values
+    for (auto &sz : haloSizes) {
+      if (sz.is<Value>()) {
+        sz = rewriter
+                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                             sz.get<Value>())
+                 .getResult();
+      }
+    }
+
+    // most of the offset/size/stride data is the same for all dims
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
+    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
+    // we need the actual shape to compute offsets and sizes
+    for (auto i = 0; i < rank; ++i) {
+      auto s = array.getType().getShape()[i];
+      if (ShapedType::isDynamic(s)) {
+        shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
+      } else {
+        shape[i] = rewriter.getIndexAttr(s);
+      }
+
+      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
+        ++currHaloDim;
+        // the offsets for lower dim sstarts after their down halo
+        offsets[i] = haloSizes[currHaloDim * 2];
+
+        // prepare shape and offsets of highest dim's halo exchange
+        auto _haloSz =
+            rewriter
+                .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                       toValue(haloSizes[currHaloDim * 2 + 1]))
+                .getResult();
+        // the halo shape of lower dims exlude the halos
+        dimSizes[i] =
+            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
+                .getResult();
+      } else {
+        dimSizes[i] = shape[i];
+      }
+    }
+
+    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
----------------
fschlimb wrote:

Since we have only COMM_WORLD this might reduce the risk of tag conflicts (like in multi-threaded cases).

https://github.com/llvm/llvm-project/pull/104566


More information about the Mlir-commits mailing list