[Mlir-commits] [mlir] [mlir][mesh, MPI] Mesh2mpi (PR #104566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 16 02:09:41 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
Pass for lowering `Mesh` to `MPI`.
Initial commit lowers `UpdateHaloOp` only.
---
Patch is 28.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104566.diff
9 Files Affected:
- (added) mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h (+27)
- (modified) mlir/include/mlir/Conversion/Passes.h (+1)
- (modified) mlir/include/mlir/Conversion/Passes.td (+17)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+33)
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
- (added) mlir/lib/Conversion/MeshToMPI/CMakeLists.txt (+22)
- (added) mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp (+225)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+19)
- (added) mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir (+173)
``````````diff
diff --git a/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
new file mode 100644
index 00000000000000..6a2c196da45577
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MeshToMPI/MeshToMPI.h
@@ -0,0 +1,27 @@
+//===- MeshToMPI.h - Convert Mesh to MPI dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Pass;
+
+// Pulls in the TableGen-generated declarations for the convert-mesh-to-mpi
+// pass (see ConvertMeshToMPIPass in Passes.td).
+// NOTE(review): for a pass without a custom `constructor`, GEN_PASS_DECL_*
+// normally also declares `createConvertMeshToMPIPass()`; the declaration
+// below would then be a redundant (but compatible) redeclaration - confirm.
+#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Creates the pass that lowers Mesh communication operations
+/// (UpdateHaloOp, AllGatherOp, ...) to MPI primitives.
+/// Currently only `mesh.update_halo` is lowered.
+std::unique_ptr<Pass> createConvertMeshToMPIPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 208f26489d6c39..ad8e98442ab8bc 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -51,6 +51,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7bde9e490e4f4e..f9a6f52a22c6ed 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -869,6 +869,23 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// MeshToMPI
+//===----------------------------------------------------------------------===//
+
+def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
+  let summary = "Convert Mesh dialect to MPI dialect.";
+  let description = [{
+    This pass converts communication operations
+    from the Mesh dialect to operations from the MPI dialect.
+    Currently only `mesh.update_halo` is converted.
+  }];
+  // The lowering creates arith (constants, index casts, integer arithmetic,
+  // comparisons), memref (alloc/subview/copy/dealloc), mpi (send/recv) and
+  // scf (if) operations, so all of these dialects must be loaded.
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect",
+    "mpi::MPIDialect",
+    "scf::SCFDialect"
+  ];
+}
+
//===----------------------------------------------------------------------===//
// NVVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8f696bbc1a0f6e..9d1684b78f34f2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -155,6 +155,39 @@ def Mesh_ProcessLinearIndexOp : Mesh_Op<"process_linear_index", [
];
}
+def Mesh_NeighborsLinearIndicesOp : Mesh_Op<"neighbors_linear_indices", [
+  Pure,
+  DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+  DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+]> {
+  let summary =
+    "For given split axes get the linear indices of the direct neighbor processes.";
+  let description = [{
+    Example:
+    ```
+    %down, %up = mesh.neighbors_linear_indices on @mesh[%d0, %d1, %d2]
+                   split_axes = [1] : index, index
+    ```
+    Given `@mesh` with shape `(10, 20, 30)`,
+    `$device` = `(1, 2, 3)`
+    `$split_axes` = `[1]`
+    it returns the linear indices of the processes at positions `(1, 1, 3)`: `633`
+    (`neighbor_down`) and `(1, 3, 3)`: `693` (`neighbor_up`).
+
+    A negative value is returned if `$device` has no neighbor in the given
+    direction along the given `split_axes`.
+  }];
+  let arguments = (ins FlatSymbolRefAttr:$mesh,
+    Variadic<Index>:$device,
+    Mesh_MeshAxesAttr:$split_axes);
+  let results = (outs Index:$neighbor_down, Index:$neighbor_up);
+  let assemblyFormat = [{
+    `on` $mesh `[` $device `]`
+    `split_axes` `=` $split_axes
+    attr-dict `:` type(results)
+  }];
+}
+
//===----------------------------------------------------------------------===//
// Sharding operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 813f700c5556e1..3ee237f4e62acd 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -41,6 +41,7 @@ add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
+add_subdirectory(MeshToMPI)
add_subdirectory(NVGPUToNVVM)
add_subdirectory(NVVMToLLVM)
add_subdirectory(OpenACCToSCF)
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
new file mode 100644
index 00000000000000..95815a683f6d6a
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMeshToMPI
+  MeshToMPI.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # MeshToMPI.cpp creates arith and scf ops and calls getMixedValues
+  # (Dialect/Utils/StaticValueUtils), so link those libraries explicitly.
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRefDialect
+  MLIRMeshDialect
+  MLIRMPIDialect
+  MLIRPass
+  MLIRSCFDialect
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
new file mode 100644
index 00000000000000..42d885a109ee79
--- /dev/null
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -0,0 +1,225 @@
+//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of Mesh communication ops to MPI ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "mesh-to-mpi"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+// This pattern converts the mesh.update_halo operation to MPI calls.
+//
+// For every array dimension that is split over mesh axes, the halo is
+// exchanged as two blocks (one per side: down and up). Each block is copied
+// between a memref.subview of the local array and a freshly allocated
+// contiguous buffer, and transferred with mpi.send / mpi.recv. scf.if guards
+// skip the transfer when the neighbor id returned by
+// mesh.neighbors_linear_indices is negative (no neighbor on that side).
+struct ConvertUpdateHaloOp
+    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Halos are exchanged as 2 blocks per dimension (one for each side: down
+    // and up). It is assumed that the last dim in a default memref is
+    // contiguous, hence iteration starts with the complete halo on the first
+    // dim which should be contiguous (unless the source is not). The size of
+    // the exchanged data will decrease when iterating over dimensions. That's
+    // good because the halos of last dim will be most fragmented.
+    // memref.subview is used to read and write the halo data from and to the
+    // local data. subviews and halos have dynamic and static values, so
+    // OpFoldResults are used whenever possible.
+
+    SymbolTableCollection symbolTableCollection;
+    auto loc = op.getLoc();
+
+    // Materialize an OpFoldResult as a Value; static (attribute) values
+    // become index-typed arith.constant ops.
+    auto toValue = [&rewriter, &loc](const OpFoldResult &v) {
+      return v.is<Value>()
+                 ? v.get<Value>()
+                 : rewriter.create<::mlir::arith::ConstantOp>(
+                       loc,
+                       rewriter.getIndexAttr(
+                           cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+    };
+
+    auto array = op.getInput();
+    auto rank = array.getType().getRank();
+    auto mesh = op.getMesh();
+    auto meshOp = getMesh(op, symbolTableCollection);
+    auto haloSizes = getMixedValues(op.getStaticHaloSizes(),
+                                    op.getDynamicHaloSizes(), rewriter);
+    // subviews need Index values
+    for (auto &sz : haloSizes) {
+      if (sz.is<Value>()) {
+        sz = rewriter
+                 .create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
+                                             sz.get<Value>())
+                 .getResult();
+      }
+    }
+
+    // most of the offset/size/stride data is the same for all dims
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> shape(rank);
+    // We need the actual shape to compute offsets and sizes. memref.dim takes
+    // the dimension *index* `i`; for dynamic dims `s` only holds the
+    // ShapedType::kDynamic marker and must not be used as the index.
+    for (auto [i, s] : llvm::enumerate(array.getType().getShape())) {
+      if (ShapedType::isDynamic(s)) {
+        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
+      } else {
+        shape[i] = rewriter.getIndexAttr(s);
+      }
+    }
+
+    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
+    auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
+    auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+                                       rewriter.getIndexType());
+    auto myMultiIndex =
+        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
+            .getResult();
+    // halo sizes are provided for split dimensions only
+    auto currHaloDim = 0;
+
+    for (auto [dim, splitAxes] : llvm::enumerate(op.getSplitAxes())) {
+      if (splitAxes.empty()) {
+        continue;
+      }
+      // Get the linearized ids of the neighbors (down and up) for the
+      // given split
+      auto neighbours = rewriter
+                            .create<NeighborsLinearIndicesOp>(
+                                loc, mesh, myMultiIndex, splitAxes)
+                            .getResults();
+      // MPI operates on i32...
+      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), neighbours[0]),
+                               rewriter.create<arith::IndexCastOp>(
+                                   loc, rewriter.getI32Type(), neighbours[1])};
+      // store for later
+      auto orgDimSize = shape[dim];
+      // this dim's offset to the start of the upper halo
+      auto upperOffset = rewriter.create<arith::SubIOp>(
+          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
+
+      // Make sure we send/recv in a way that does not lead to a dead-lock.
+      // The current approach is by far not optimal, this should be at least
+      // be a red-black pattern or using MPI_sendrecv.
+      // Also, buffers should be re-used.
+      // Still using temporary contiguous buffers for MPI communication...
+      // Still yielding a "serialized" communication pattern...
+      //
+      // NOTE(review): the send subviews start at `upperOffset` (for the lower
+      // halo block) and at 0 (for the upper halo block), i.e. inside the halo
+      // regions themselves. Each received block is also stored on the side
+      // opposite to the neighbor it comes from (the lower halo receives from
+      // neighbourIDs[1]). Confirm this matches the intended semantics of
+      // mesh.update_halo; the CHECK lines in convert-mesh-to-mpi.mlir
+      // currently pin this behavior.
+      auto genSendRecv = [&](auto dimIdx, bool upperHalo) {
+        auto orgOffset = offsets[dimIdx];
+        shape[dimIdx] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
+                                  : haloSizes[currHaloDim * 2];
+        // Check if we need to send and/or receive
+        // Processes on the mesh borders have only one neighbor
+        auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
+        auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+        auto hasFrom = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, from, zero);
+        auto hasTo = rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sge, to, zero);
+        auto buffer = rewriter.create<memref::AllocOp>(
+            loc, shape, array.getType().getElementType());
+        // if has neighbor: copy halo data from array to buffer and send
+        rewriter.create<scf::IfOp>(
+            loc, hasTo, [&](OpBuilder &builder, Location loc) {
+              offsets[dimIdx] = upperHalo
+                                    ? OpFoldResult(builder.getIndexAttr(0))
+                                    : OpFoldResult(upperOffset);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, subview, buffer);
+              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to);
+              builder.create<scf::YieldOp>(loc);
+            });
+        // if has neighbor: receive halo data into buffer and copy to array
+        rewriter.create<scf::IfOp>(
+            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
+              offsets[dimIdx] = upperHalo
+                                    ? OpFoldResult(upperOffset)
+                                    : OpFoldResult(builder.getIndexAttr(0));
+              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from);
+              auto subview = builder.create<memref::SubViewOp>(
+                  loc, array, offsets, shape, strides);
+              builder.create<memref::CopyOp>(loc, buffer, subview);
+              builder.create<scf::YieldOp>(loc);
+            });
+        rewriter.create<memref::DeallocOp>(loc, buffer);
+        offsets[dimIdx] = orgOffset;
+      };
+
+      genSendRecv(dim, false);
+      genSendRecv(dim, true);
+
+      // prepare shape and offsets for next split dim
+      auto haloSizeSum =
+          rewriter
+              .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
+                                     toValue(haloSizes[currHaloDim * 2 + 1]))
+              .getResult();
+      // the shape for next halo excludes the halo on both ends for the
+      // current dim
+      shape[dim] =
+          rewriter.create<arith::SubIOp>(loc, toValue(orgDimSize), haloSizeSum)
+              .getResult();
+      // the offsets for next halo starts after the down halo for the
+      // current dim
+      offsets[dim] = haloSizes[currHaloDim * 2];
+      // on to next halo
+      ++currHaloDim;
+    }
+    rewriter.eraseOp(op);
+    return mlir::success();
+  }
+};
+
+struct ConvertMeshToMPIPass
+    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
+  using Base::Base;
+
+  /// Greedily apply the Mesh-to-MPI rewrite patterns to the root operation.
+  /// The pass fails if the greedy driver reports failure (e.g. it did not
+  /// converge), instead of silently leaving IR partially rewritten.
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    mlir::RewritePatternSet patterns(ctx);
+
+    patterns.insert<ConvertUpdateHaloOp>(ctx);
+
+    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
+                                                  std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+// The factory `mlir::createConvertMeshToMPIPass()` declared in MeshToMPI.h is
+// emitted by TableGen through GEN_PASS_DEF_CONVERTMESHTOMPIPASS above (the
+// pass definition in Passes.td sets no custom `constructor`). A hand-written
+// definition here would either be a redefinition (inside `mlir`) or, as
+// before, an unrelated global-namespace overload with a mismatched return type.
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c35020b4c20ccc..f25bbbf8e274b6 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -730,6 +730,25 @@ void ProcessLinearIndexOp::getAsmResultNames(
setNameFn(getResult(), "proc_linear_idx");
}
+//===----------------------------------------------------------------------===//
+// mesh.neighbors_linear_indices op
+//===----------------------------------------------------------------------===//
+
+/// Verifies that `$mesh` refers to a mesh, that `$device` provides one index
+/// per mesh dimension, and that every split axis is a valid mesh axis.
+LogicalResult
+NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
+  if (failed(mesh)) {
+    return failure();
+  }
+  int64_t meshRank = mesh->getRank();
+  // `$device` is a multi-index into the mesh.
+  if (static_cast<int64_t>(getDevice().size()) != meshRank) {
+    return emitOpError("expected ")
+           << meshRank << " device indices, got " << getDevice().size();
+  }
+  for (auto axis : getSplitAxes()) {
+    if (axis < 0 || axis >= meshRank) {
+      return emitOpError("split axis ")
+             << axis << " is out of bounds for mesh of rank " << meshRank;
+    }
+  }
+  return success();
+}
+
+/// Suggests readable SSA names for the two neighbor results in printed IR.
+void NeighborsLinearIndicesOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  Value downIdx = getNeighborDown();
+  Value upIdx = getNeighborUp();
+  setNameFn(downIdx, "down_linear_idx");
+  setNameFn(upIdx, "up_linear_idx");
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
new file mode 100644
index 00000000000000..5f563364272d96
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -0,0 +1,173 @@
+// RUN: mlir-opt %s -convert-mesh-to-mpi | FileCheck %s
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 2x2x4)
+
+// CHECK-LABEL: func @update_halo_1d_first
+func.func @update_halo_1d_first(
+ // CHECK-SAME: [[varg0:%.*]]: memref<12x12xi8>
+ %arg0 : memref<12x12xi8>) {
+ // CHECK-NEXT: [[vc9:%.*]] = arith.constant 9 : index
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vproc_linear_idx:%.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-NEXT: [[vdown_linear_idx:%.*]], [[vup_linear_idx:%.*]] = mesh.neighbors_linear_indices on @mesh0[[[vproc_linear_idx]]#0, [[vproc_linear_idx]]#1, [[vproc_linear_idx]]#2] split_axes = [0] : index, index
+ // CHECK-NEXT: [[v0:%.*]] = arith.index_cast [[vdown_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v1:%.*]] = arith.index_cast [[vup_linear_idx]] : index to i32
+ // CHECK-NEXT: [[v2:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v3:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x12xi8>
+ // CHECK-NEXT: scf.if [[v3]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x12xi8, strided<[12, 1], offset: ?>> to memref<2x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[v0]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v2]] {
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[v1]]) : memref<2x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [2, 12] [1, 1] : memref<12x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview]] : memref<2x12xi8> to memref<2x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x12xi8>
+ // CHECK-NEXT: [[v4:%.*]] = arith.cmpi sge, [[v0]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[v5:%.*]] = arith.cmpi sge, [[v1]], [[vc0_i32]] : i32
+ // CHECK-NEXT: [[valloc_0:%.*]] = memref.alloc() : memref<3x12xi8>
+ // CHECK-NEXT: scf.if [[v5]] {
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][0, 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1]>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc_0]] : memref<3x12xi8, strided<[12, 1]>> to memref<3x12xi8>
+ // CHECK-NEXT: mpi.send([[valloc_0]], [[vc91_i32]], [[v1]]) : memref<3x12xi8>, i32, i32
+ // CHECK-NEXT: }
+ // CHECK-NEXT: scf.if [[v4]] {
+ // CHECK-NEXT: mpi.recv([[valloc_0]], [[vc91_i32]], [[v0]]) : memref<3x12xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc9]], 0] [3, 12] [1, 1] : memref<12x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: memref.copy [[valloc_0]], [[vsubview]] : memref<3x12xi8> to memref<3x12xi8, strided<[12, 1], offset: ?>>
+ // CHECK-NEXT: }
+ // CHECK-NEXT: memref.dealloc [[valloc_0]] : memref<3x12xi8>
+ // CHECK-NEXT: return
+ mesh.update_halo %arg0 on @mesh0 split_axes = [[0]]
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/104566
More information about the Mlir-commits
mailing list