[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)
Tuomas Kärnä
llvmlistbot at llvm.org
Fri Nov 15 04:05:46 PST 2024
================
@@ -0,0 +1,433 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+/// Emits arith ops that decompose \p linearIndex into one coordinate per
+/// entry of \p dimensions, treating the last dimension as the fastest-varying
+/// one (row-major order).
+///
+/// Working from the innermost dimension outwards, each coordinate is the
+/// signed remainder (arith.remsi) of the still-unconsumed index by that
+/// dimension's extent; the unconsumed index is then divided (arith.divsi) by
+/// the same extent. No division is emitted for dimension 0 because its
+/// quotient would never be used.
+///
+/// \param loc          Location attached to every created op.
+/// \param b            Builder whose insertion point receives the new ops.
+/// \param linearIndex  The flattened index to decompose.
+/// \param dimensions   Extent of each dimension.
+/// \returns a vector whose i-th value is the coordinate along dimensions[i];
+///          empty when \p dimensions is empty.
+static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
+                                             Value linearIndex,
+                                             ValueRange dimensions) {
+  const int rank = dimensions.size();
+  SmallVector<Value> coords(rank);
+
+  // Peel off coordinates from the innermost dimension outwards.
+  Value remaining = linearIndex;
+  for (int d = rank - 1; d >= 0; --d) {
+    Value extent = dimensions[d];
+    coords[d] = b.create<arith::RemSIOp>(loc, remaining, extent);
+    if (d == 0)
+      break; // No outer dimension is left to consume the quotient.
+    remaining = b.create<arith::DivSIOp>(loc, remaining, extent);
+  }
+
+  return coords;
+}
+
+/// Emits arith ops that fold a multi-dimensional index into a linear index,
+/// treating the last dimension as the fastest-varying one (row-major order):
+///   linear = sum_i multiIndex[i] * prod_{j > i} dimensions[j]
+///
+/// \param loc         Location attached to every created op.
+/// \param b           Builder whose insertion point receives the new ops.
+/// \param multiIndex  Per-dimension coordinates; they are multiplied with
+///                    index constants, so they must be of index type.
+/// \param dimensions  Extent of each dimension; must have the same length as
+///                    \p multiIndex.
+/// \returns the linear index; a constant 0 when \p multiIndex is empty.
+static Value multiToLinearIndex(Location loc, OpBuilder b,
+                                ValueRange multiIndex,
+                                ValueRange dimensions) {
+  assert(multiIndex.size() == dimensions.size() &&
+         "multi-index rank must match the number of dimensions");
+
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
+
+  for (int i = static_cast<int>(multiIndex.size()) - 1; i >= 0; --i) {
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
+    // The stride computed after the outermost dimension is never read, so do
+    // not emit a dead multiplication for i == 0.
+    if (i > 0)
+      stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
+  }
+
+  return linearIndex;
+}
+
+// This pattern converts the mesh.update_halo operation to MPI calls
----------------
tkarna wrote:
Better docstring: this pattern just converts the process index. Similar issue with the following patterns.
https://github.com/llvm/llvm-project/pull/104566
More information about the Mlir-commits
mailing list