[Mlir-commits] [mlir] More on MeshToMPI (PR #129048)
Frank Schlimbach
llvmlistbot at llvm.org
Thu Feb 27 05:09:56 PST 2025
https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/129048
- do not create MPI operations if no halo exchange is needed on a given boundary
- allow returning sharding information by returning `!mesh.sharding` values, which get converted into a tuple of tensors (see the example after this list)
- lowering of `mesh.shard_shape`, including fixes to the operation itself
- global symbol `static_mpi_rank` replaced by a DLTI attribute (now aligned with MPIToLLVM)
- smaller fixes and some minor cleanup
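For illustration of the second item, taken from the `return_sharding` test added in patch 3 below: a function returning a sharding value

  func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
    %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
    return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
  }

is rewritten so that the sharding result becomes three tensors (split axes, halo sizes, sharded-dims offsets), i.e. the function signature turns into
(tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>).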
>From 8dabd5755d095bfa1b231375f9e1907f4e0c3bb5 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 13 Dec 2024 16:24:43 +0100
Subject: [PATCH 1/8] handling empty halos when lowering mesh.update_halo
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..e36498b72009c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -236,6 +236,14 @@ struct ConvertUpdateHaloOp
// local data. Because subviews and halos can have mixed dynamic and static
// shapes, OpFoldResults are used whenever possible.
+ auto haloSizes =
+ getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
+ if (haloSizes.empty()) {
+ // no halos -> nothing to do
+ rewriter.replaceOp(op, op.getDestination());
+ return mlir::success();
+ }
+
SymbolTableCollection symbolTableCollection;
auto loc = op.getLoc();
@@ -262,8 +270,6 @@ struct ConvertUpdateHaloOp
auto opSplitAxes = op.getSplitAxes().getAxes();
auto mesh = op.getMesh();
auto meshOp = getMesh(op, symbolTableCollection);
- auto haloSizes =
- getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
// subviews need Index values
for (auto &sz : haloSizes) {
if (auto value = dyn_cast<Value>(sz)) {
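The case handled here is an `update_halo` that carries no halo sizes at all; a minimal sketch (hypothetical IR, assuming the halo_sizes clause can be omitted when all halos are empty):

  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] : memref<120x120x120xi8>

With no halos to exchange, the op is now simply replaced by its destination instead of emitting any MPI operations.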
>From 29d7fb03b08c03d82d171bb286b7cf6ca688d0d4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 19 Dec 2024 13:18:07 +0100
Subject: [PATCH 2/8] fixing send/recv direction for halo updates; don't create
send/recv for empty halos
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 23 ++++++++++++++----
.../MeshToMPI/convert-mesh-to-mpi.mlir | 24 +++++++++++++++++++
2 files changed, 43 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index e36498b72009c..9fa718f1f1acb 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -362,8 +362,8 @@ struct ConvertUpdateHaloOp
: haloSizes[currHaloDim * 2];
// Check if we need to send and/or receive
// Processes on the mesh borders have only one neighbor
- auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
- auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
auto hasFrom = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, from, zero);
auto hasTo = rewriter.create<arith::CmpIOp>(
@@ -396,8 +396,23 @@ struct ConvertUpdateHaloOp
offsets[dim] = orgOffset;
};
- genSendRecv(false);
- genSendRecv(true);
+ auto get_i32val = [&](OpFoldResult &v) {
+ return v.is<Value>()
+ ? v.get<Value>()
+ : rewriter.create<::mlir::arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(
+ cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ };
+
+ for (int i=0; i<2; ++i) {
+ Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+ auto hasSize = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, haloSz, zero);
+ rewriter.create<scf::IfOp>(
+ loc, hasSize, [&](OpBuilder &builder, Location loc) {
+ genSendRecv(i > 0);
+ builder.create<scf::YieldOp>(loc);
+ });
+ }
// the shape for lower dims include higher dims' halos
dimSizes[dim] = shape[dim];
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index c1aef97438bd5..04832ce11b7b3 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -102,6 +102,30 @@ func.func @update_halo_1d_first(
return %res : memref<120x120x120xi8>
}
+// -----
+mesh.mesh @mesh0(shape = 4)
+memref.global "public" constant @static_mpi_rank : memref<index> = dense<1>
+// CHECK-LABEL: func @update_halo_1d_with_zero
+func.func @update_halo_1d_with_zero (
+ // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+ // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
+ // CHECK-SAME: to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+ // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+}
+
// -----
mesh.mesh @mesh0(shape = 3x4x5)
memref.global constant @static_mpi_rank : memref<index> = dense<24>
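Each side's send/recv is now guarded by a runtime check on the halo size; a rough sketch of the guard emitted when a halo size is not statically known (value names are illustrative; statically-known sizes end up without the runtime check, as the tests show):

  %has_size = arith.cmpi sgt, %halo_sz, %c0_i32 : i32
  scf.if %has_size {
    // mpi.send / mpi.recv for this side
  }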
>From d9bd663f5a0aa2b23510b32ac615a0d008e37c52 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 20 Jan 2025 13:57:44 +0100
Subject: [PATCH 3/8] adding type conversion of Sharding to 3 tensors
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 292 ++++++++++++++++--
.../MeshToMPI/convert-mesh-to-mpi.mlir | 68 ++++
2 files changed, 328 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9fa718f1f1acb..91f235673f817 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +29,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "mesh-to-mpi"
@@ -72,13 +77,181 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
return linearIndex;
}
+// replace GetShardingOp with related ShardingOp
+struct ConvertGetShardingOp
+ : public mlir::OpConversionPattern<mlir::mesh::GetShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mlir::mesh::GetShardingOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp) {
+ return mlir::failure();
+ }
+ auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+ if (!shardingOp) {
+ return mlir::failure();
+ }
+
+ rewriter.replaceOp(op, shardingOp.getResult());
+ return mlir::success();
+ }
+};
+
+struct ConvertShardingOp
+ : public mlir::OpConversionPattern<mlir::mesh::ShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mlir::mesh::ShardingOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ auto splitAxes = op.getSplitAxes().getAxes();
+ int64_t maxNAxes = 0;
+ for (auto axes : splitAxes) {
+ maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+ }
+
+ // To hold the split axes, create empty 2d tensor with shape
+ // {splitAxes.size(), max-size-of-split-groups}.
+ // Set trailing elements for smaller split-groups to -1.
+ auto loc = op.getLoc();
+ auto i16 = rewriter.getI16Type();
+ auto i64 = rewriter.getI64Type();
+ int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+ Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+ auto attr = IntegerAttr::get(i16, 0xffff);
+ auto fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+ resSplitAxes =
+ rewriter
+ .create<linalg::FillOp>(loc, fillValue.getResult(), resSplitAxes)
+ .getResult(0);
+
+ // explicitly write values into tensor row by row
+ int64_t strides[] = {1, 1};
+ int64_t nSplits = 0;
+ ValueRange empty = {};
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t size = axes.size();
+ if (size > 0) {
+ ++nSplits;
+ }
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, size};
+ auto tensorType = RankedTensorType::get({size}, i16);
+ auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+ auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+ resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+ }
+
+ // convert vec of OpFoldResults (ints) into vec of Values
+ auto getMixedAsValues = [&](llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics) {
+ SmallVector<Value> values;
+ auto dyn = dynamics.begin();
+ for (auto s : statics) {
+ values.emplace_back(
+ ShapedType::isDynamic(s)
+ ? *(dyn++)
+ : rewriter
+ .create<arith::ConstantOp>(loc, i64,
+ rewriter.getI64IntegerAttr(s))
+ .getResult());
+ }
+ return values;
+ };
+
+ // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+ // Store the halo sizes in the tensor.
+ auto haloSizes = getMixedAsValues(adaptor.getStaticHaloSizes(),
+ adaptor.getDynamicHaloSizes());
+ auto type = RankedTensorType::get({nSplits, 2}, i64);
+ Value resHaloSizes =
+ haloSizes.empty()
+ ? rewriter
+ .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+ i64)
+ .getResult()
+ : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+ .getResult();
+
+ // To hold sharded dims offsets, create Tensor with shape {nSplits,
+ // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+ // elements for smaller split-groups to -1. Computing the max size of the
+ // split groups needs using collectiveProcessGroupSize (which needs the
+ // MeshOp)
+ Value resOffsets;
+ if (adaptor.getStaticShardedDimsOffsets().empty()) {
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{0, 0}, i64);
+ } else {
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto maxSplitSize = 0;
+ for (auto axes : splitAxes) {
+ auto splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic);
+ maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+ }
+ assert(maxSplitSize);
+ ++maxSplitSize; // add one for the total size
+
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+ auto zero =
+ rewriter
+ .create<arith::ConstantOp>(
+ loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic))
+ .getResult();
+ resOffsets =
+ rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+ auto offsets = getMixedAsValues(adaptor.getStaticShardedDimsOffsets(),
+ adaptor.getDynamicShardedDimsOffsets());
+ int64_t curr = 0;
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ auto splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+ ++splitSize; // add one for the total size
+ ArrayRef<Value> values(&offsets[curr], splitSize);
+ auto vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, splitSize};
+ resOffsets = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+ curr += splitSize;
+ }
+ }
+
+ // return a tuple of tensors as defined by type converter
+ SmallVector<Type> resTypes;
+ if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+ resTypes))) {
+ return mlir::failure();
+ }
+ resSplitAxes =
+ rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+ resHaloSizes =
+ rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+ resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, TupleType::get(op.getContext(), resTypes),
+ ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+ return mlir::success();
+ }
+};
+
struct ConvertProcessMultiIndexOp
- : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public mlir::OpConversionPattern<mlir::mesh::ProcessMultiIndexOp> {
+ using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
// Currently converts its linear index to a multi-dimensional index.
@@ -100,7 +273,7 @@ struct ConvertProcessMultiIndexOp
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// optionally extract subset of mesh axes
- auto axes = op.getAxes();
+ auto axes = adaptor.getAxes();
if (!axes.empty()) {
SmallVector<Value> subIndex;
for (auto axis : axes) {
@@ -115,12 +288,12 @@ struct ConvertProcessMultiIndexOp
};
struct ConvertProcessLinearIndexOp
- : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public mlir::OpConversionPattern<mlir::mesh::ProcessLinearIndexOp> {
+ using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
// Finds a global named "static_mpi_rank" it will use that splat value.
// Otherwise it defaults to mpi.comm_rank.
@@ -149,18 +322,18 @@ struct ConvertProcessLinearIndexOp
};
struct ConvertNeighborsLinearIndicesOp
- : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public mlir::OpConversionPattern<mlir::mesh::NeighborsLinearIndicesOp> {
+ using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult
- matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
// Computes the neighbors indices along a split axis by simply
// adding/subtracting 1 to the current index in that dimension.
// Assigns -1 if neighbor is out of bounds.
- auto axes = op.getSplitAxes();
+ auto axes = adaptor.getSplitAxes();
// For now only single axis sharding is supported
if (axes.size() != 1) {
return mlir::failure();
@@ -169,7 +342,7 @@ struct ConvertNeighborsLinearIndicesOp
auto loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
- auto mIdx = op.getDevice();
+ auto mIdx = adaptor.getDevice();
auto orgIdx = mIdx[axes[0]];
SmallVector<Value> dims;
llvm::transform(
@@ -217,12 +390,12 @@ struct ConvertNeighborsLinearIndicesOp
};
struct ConvertUpdateHaloOp
- : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public mlir::OpConversionPattern<mlir::mesh::UpdateHaloOp> {
+ using OpConversionPattern::OpConversionPattern;
mlir::LogicalResult
- matchAndRewrite(mlir::mesh::UpdateHaloOp op,
- mlir::PatternRewriter &rewriter) const override {
+ matchAndRewrite(mlir::mesh::UpdateHaloOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
// The input/output memref is assumed to be in C memory order.
// Halos are exchanged as 2 blocks per dimension (one for each side: down
@@ -236,11 +409,11 @@ struct ConvertUpdateHaloOp
// local data. Because subviews and halos can have mixed dynamic and static
// shapes, OpFoldResults are used whenever possible.
- auto haloSizes =
- getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
+ auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
+ adaptor.getHaloSizes(), rewriter);
if (haloSizes.empty()) {
// no halos -> nothing to do
- rewriter.replaceOp(op, op.getDestination());
+ rewriter.replaceOp(op, adaptor.getDestination());
return mlir::success();
}
@@ -256,7 +429,7 @@ struct ConvertUpdateHaloOp
cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
};
- auto dest = op.getDestination();
+ auto dest = adaptor.getDestination();
auto dstShape = cast<ShapedType>(dest.getType()).getShape();
Value array = dest;
if (isa<RankedTensorType>(array.getType())) {
@@ -267,8 +440,8 @@ struct ConvertUpdateHaloOp
.getResult();
}
auto rank = cast<ShapedType>(array.getType()).getRank();
- auto opSplitAxes = op.getSplitAxes().getAxes();
- auto mesh = op.getMesh();
+ auto opSplitAxes = adaptor.getSplitAxes().getAxes();
+ auto mesh = adaptor.getMesh();
auto meshOp = getMesh(op, symbolTableCollection);
// subviews need Index values
for (auto &sz : haloSizes) {
@@ -440,14 +613,69 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
- auto *ctx = &getContext();
- mlir::RewritePatternSet patterns(ctx);
-
- patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
- ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
- ctx);
+ auto *ctxt = &getContext();
+ mlir::RewritePatternSet patterns(ctxt);
+ ConversionTarget target(getContext());
+ mlir::TypeConverter typeConverter;
+
+ typeConverter.addConversion([](Type type) { return type; });
+ // convert mesh::ShardingType to a tuple of RankedTensorTypes
+ typeConverter.addConversion(
+ [](ShardingType type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp{ShapedType::kDynamic,
+ ShapedType::kDynamic};
+ results.emplace_back(RankedTensorType::get(shp, i16));
+ results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+ results.emplace_back(RankedTensorType::get(shp, i64));
+ return mlir::success();
+ });
+ // To extract components a UnrealizedConversionCastOp is expected to define
+ // the input
+ typeConverter.addTargetMaterialization(
+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc) {
+ if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType())) {
+ return SmallVector<Value>();
+ }
+ auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+ if (!castOp) {
+ return SmallVector<Value>();
+ }
+ SmallVector<Value> results;
+ for (auto oprnd : castOp.getInputs()) {
+ if (!isa<RankedTensorType>(oprnd.getType())) {
+ return SmallVector<Value>();
+ }
+ results.emplace_back(oprnd);
+ }
+ return results;
+ });
- (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+ target.addIllegalDialect<mesh::MeshDialect>();
+ target.addLegalOp<mesh::MeshOp>();
+ target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+ arith::ArithDialect, tensor::TensorDialect,
+ bufferization::BufferizationDialect,
+ linalg::LinalgDialect, memref::MemRefDialect>();
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType());
+ });
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op); });
+
+ patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+ ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp,
+ ConvertShardingOp, ConvertGetShardingOp>(typeConverter, ctxt);
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+ (void)mlir::applyPartialConversion(getOperation(), target,
+ std::move(patterns));
}
};
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 04832ce11b7b3..4ed6a2ffcd683 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -230,3 +230,71 @@ func.func @update_halo_3d_tensor(
// CHECK: return [[v1]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
}
+
+// -----
+mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK-LABEL: func.func @return_sharding(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+ // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
+ // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+ return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+}
+
+// CHECK-LABEL: func.func @return_sharding_halos(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+ // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+ // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
+ // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+ return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
+}
+
+// CHECK-LABEL: func.func @return_sharding_offs(
+// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+ // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+ // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+ // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+ // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
+ // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
+ // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+ // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+ // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
+ // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+ return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
+}
>From d4dbd746e0b2f8fc728c00c0053923146d9c74f1 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 24 Jan 2025 11:29:09 +0100
Subject: [PATCH 4/8] fixing convert-mesh-to-mpi.mlir tests
---
.../MeshToMPI/convert-mesh-to-mpi.mlir | 118 +++++++++---------
1 file changed, 59 insertions(+), 59 deletions(-)
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 4ed6a2ffcd683..6d865ddaf7ce5 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -133,7 +133,7 @@ memref.global constant @static_mpi_rank : memref<index> = dense<24>
func.func @update_halo_3d(
// CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
%arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
- // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
// CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
// CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
@@ -141,39 +141,39 @@ func.func @update_halo_3d(
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
- // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
- // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
- // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
- // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
- // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
- // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
- // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+ // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
// CHECK: return [[varg0]] : memref<120x120x120xi8>
return %res : memref<120x120x120xi8>
@@ -183,7 +183,7 @@ func.func @update_halo_3d(
func.func @update_halo_3d_tensor(
// CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
%arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
- // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
// CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
// CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
// CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
@@ -192,40 +192,40 @@ func.func @update_halo_3d_tensor(
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
// CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
// CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
// CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
// CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
- // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
- // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
// CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
- // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
- // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
- // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
- // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
- // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
- // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+ // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+ // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
%res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
// CHECK: return [[v1]] : tensor<120x120x120xi8>
return %res : tensor<120x120x120xi8>
>From 806ffd4207acf157f6a6def219b3e06da6e003c7 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Mon, 24 Feb 2025 12:50:23 +0100
Subject: [PATCH 5/8] clang-format
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 22 +++++++++++----------
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 91f235673f817..21f1ba411dafb 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -573,18 +573,20 @@ struct ConvertUpdateHaloOp
return v.is<Value>()
? v.get<Value>()
: rewriter.create<::mlir::arith::ConstantOp>(
- loc, rewriter.getI32IntegerAttr(
- cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ loc,
+ rewriter.getI32IntegerAttr(
+ cast<IntegerAttr>(v.get<Attribute>()).getInt()));
};
- for (int i=0; i<2; ++i) {
- Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
- auto hasSize = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, haloSz, zero);
- rewriter.create<scf::IfOp>(
- loc, hasSize, [&](OpBuilder &builder, Location loc) {
- genSendRecv(i > 0);
- builder.create<scf::YieldOp>(loc);
- });
+ for (int i = 0; i < 2; ++i) {
+ Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+ auto hasSize = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, haloSz, zero);
+ rewriter.create<scf::IfOp>(loc, hasSize,
+ [&](OpBuilder &builder, Location loc) {
+ genSendRecv(i > 0);
+ builder.create<scf::YieldOp>(loc);
+ });
}
// the shape for lower dims include higher dims' halos
>From 35f31d840023cc9fdf4050f0604586141794d848 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 26 Feb 2025 12:15:58 +0100
Subject: [PATCH 6/8] fixes to and lowering of mesh.shard_shape
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 22 ++-
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 176 +++++++++++++++---
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 15 +-
.../Extensions/MeshShardingExtensions.cpp | 19 +-
.../MeshToMPI/convert-shardshape-to-mpi.mlir | 75 ++++++++
mlir/test/Dialect/Mesh/ops.mlir | 10 +-
.../test/Dialect/Tensor/mesh-spmdization.mlir | 11 +-
7 files changed, 275 insertions(+), 53 deletions(-)
create mode 100644 mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
}];
}
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
- let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+ Pure, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Get the shard shape for a given process/device.";
let description = [{
- The device/process id is a linearized id of the device/process in the mesh.
+ The device/process id is a multi-index of the device/process in the mesh.
This operation might be used during spmdization when the shard shape depends
on (non-constant) values used in `mesh.sharding`.
}];
let arguments = (ins
- DenseI64ArrayAttr:$shape,
+ DenseI64ArrayAttr:$dims,
+ Variadic<Index>:$dims_dynamic,
Mesh_Sharding:$sharding,
- Index:$device
+ DenseI64ArrayAttr:$device,
+ Variadic<Index>:$device_dynamic
);
let results = (outs Variadic<Index>:$result);
let assemblyFormat = [{
- custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+ `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+ `sharding` `=` $sharding
+ `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+ attr-dict `:` type(results)
}];
let builders = [
- OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+ OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
];
}
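A hypothetical example of the new assembly format (operand names and values are illustrative; `%sharding` is assumed to be an `!mesh.sharding` value and `%dev0`, `%dev1` index values identifying the device in the mesh):

  %shard_shape:2 = mesh.shard_shape dims = [16, 32] sharding = %sharding device = [%dev0, %dev1] : index, index

Both the dims and the device multi-index accept a mix of static integers and dynamic `index` operands via the DynamicIndexList syntax.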
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 21f1ba411dafb..b94181dd4f061 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -44,6 +44,29 @@ using namespace mlir;
using namespace mlir::mesh;
namespace {
+// convert vec of OpFoldResults (ints) into vec of Values
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics,
+ Type type = Type()) {
+ SmallVector<Value> values;
+ auto dyn = dynamics.begin();
+ Type i64 = b.getI64Type();
+ if (!type)
+ type = i64;
+ assert(i64 == type || b.getIndexType() == type);
+ for (auto s : statics) {
+ values.emplace_back(
+ ShapedType::isDynamic(s)
+ ? *(dyn++)
+ : b.create<arith::ConstantOp>(loc, type,
+ i64 == type ? b.getI64IntegerAttr(s)
+ : b.getIndexAttr(s))
+ .getResult());
+ }
+ return values;
+};
+
// Create operations converting a linear index to a multi-dimensional index
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
Value linearIndex,
@@ -145,27 +168,11 @@ struct ConvertShardingOp
loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
}
- // convert vec of OpFoldResults (ints) into vec of Values
- auto getMixedAsValues = [&](llvm::ArrayRef<int64_t> statics,
- ValueRange dynamics) {
- SmallVector<Value> values;
- auto dyn = dynamics.begin();
- for (auto s : statics) {
- values.emplace_back(
- ShapedType::isDynamic(s)
- ? *(dyn++)
- : rewriter
- .create<arith::ConstantOp>(loc, i64,
- rewriter.getI64IntegerAttr(s))
- .getResult());
- }
- return values;
- };
-
// To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
// Store the halo sizes in the tensor.
- auto haloSizes = getMixedAsValues(adaptor.getStaticHaloSizes(),
- adaptor.getDynamicHaloSizes());
+ auto haloSizes =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+ adaptor.getDynamicHaloSizes());
auto type = RankedTensorType::get({nSplits, 2}, i64);
Value resHaloSizes =
haloSizes.empty()
@@ -207,8 +214,9 @@ struct ConvertShardingOp
.getResult();
resOffsets =
rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
- auto offsets = getMixedAsValues(adaptor.getStaticShardedDimsOffsets(),
- adaptor.getDynamicShardedDimsOffsets());
+ auto offsets =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+ adaptor.getDynamicShardedDimsOffsets());
int64_t curr = 0;
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
auto splitSize =
@@ -389,6 +397,123 @@ struct ConvertNeighborsLinearIndicesOp
}
};
+struct ConvertShardShapeOp
+ : public mlir::OpConversionPattern<mlir::mesh::ShardShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(mlir::mesh::ShardShapeOp op, OneToNOpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+ if (!sharding) {
+ return op->emitError()
+ << "Expected SharingOp as defining op for sharding"
+ << " but found " << adaptor.getSharding()[0].getDefiningOp();
+ }
+
+ auto loc = op.getLoc();
+ Type index = rewriter.getIndexType();
+
+ // Fill a vector with the dynamic dims
+ SmallVector<Value> dynDims, dynDevice;
+ for (auto dim : adaptor.getDimsDynamic()) {
+ // type conversion should be 1:1 for ints
+ assert(dim.size() == 1);
+ dynDims.emplace_back(dim[0]);
+ }
+ // same for device
+ for (auto device : adaptor.getDeviceDynamic()) {
+ assert(device.size() == 1);
+ dynDevice.emplace_back(device[0]);
+ }
+
+ // To keep the code simple, we convert dims/device to values when they are
+ // attributes
+ auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+ auto multiIdx =
+ getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(sharding, symbolTableCollection);
+ // For now we only support static mesh shapes
+ if (ShapedType::isDynamicShape(meshOp.getShape())) {
+ return mlir::failure();
+ }
+
+ auto splitAxes = sharding.getSplitAxes().getAxes();
+ Value shardedDimsOffs;
+ {
+ auto tmp = getMixedAsValues(
+ rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+ sharding.getDynamicShardedDimsOffsets(), index);
+ if (!tmp.empty())
+ shardedDimsOffs =
+ rewriter
+ .create<tensor::FromElementsOp>(
+ loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
+ tmp)
+ .getResult();
+ }
+
+ int64_t pos = 0; // position in shardedDimsOffs
+ SmallVector<Value> shardShape;
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+ Value one =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ if (i < splitAxes.size()) { // trailing dimensions might not be annotated
+ auto axes = splitAxes[i];
+ if (!axes.empty()) {
+ Value posVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(pos));
+ Value idx = multiIdx[axes[0]];
+ auto _numShards =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ if (shardedDimsOffs) {
+ // If sharded dims offsets are provided, use them to compute the
+ // sharded shape.
+ if (axes.size() > 1)
+ return op->emitError() << "Only single axis sharding is "
+ "supported for each dimension.";
+ idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+ Value off =
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx)
+ .getResult();
+ idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+ Value nextOff =
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx)
+ .getResult();
+ Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+ shardShape.emplace_back(sz);
+ } else {
+ auto numShards = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(_numShards));
+ // compute shard dim size by distributing odd elements to trailing
+ // shards sz = dim / numShards + (idx >= (numShards - (dim %
+ // numShards)) ? one : zero)
+ Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
+ Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
+ sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+ auto cond = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, idx, sz1);
+ Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
+ sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+ shardShape.emplace_back(sz);
+ }
+ pos += _numShards + 1; // add one for the total size
+ } // else no sharding in this dimension
+ }
+ // If no size was added -> no sharding in this dimension.
+ if (shardShape.size() <= i)
+ shardShape.emplace_back(dim);
+ }
+ assert(shardShape.size() == shape.size());
+ rewriter.replaceOp(op, shardShape);
+ return mlir::success();
+ }
+};
+
struct ConvertUpdateHaloOp
: public mlir::OpConversionPattern<mlir::mesh::UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;
@@ -570,12 +695,12 @@ struct ConvertUpdateHaloOp
};
auto get_i32val = [&](OpFoldResult &v) {
- return v.is<Value>()
- ? v.get<Value>()
+ return isa<Value>(v)
+ ? cast<Value>(v)
: rewriter.create<::mlir::arith::ConstantOp>(
loc,
rewriter.getI32IntegerAttr(
- cast<IntegerAttr>(v.get<Attribute>()).getInt()));
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
};
for (int i = 0; i < 2; ++i) {
@@ -670,7 +795,8 @@ struct ConvertMeshToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp,
- ConvertShardingOp, ConvertGetShardingOp>(typeConverter, ctxt);
+ ConvertShardingOp, ConvertGetShardingOp, ConvertShardShapeOp>(
+ typeConverter, ctxt);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304ede195c762..3e9f86fde64f3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -831,12 +831,19 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
// mesh.shard_shape
//===----------------------------------------------------------------------===//
+void ShardShapeOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult()[0], "shard_shape");
+}
+
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
- ::llvm::ArrayRef<int64_t> shape,
- ::mlir::Value sharding, ::mlir::Value device) {
- SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
- build(odsBuilder, odsState, resType, shape, sharding, device);
+ ::llvm::ArrayRef<int64_t> dims,
+ ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
+ ::mlir::ValueRange device) {
+ SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
+ build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
+ SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}
//===----------------------------------------------------------------------===//
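For reference, the new builder mirrors the revised assembly of mesh.shard_shape, which now takes mixed static/dynamic dims and a multi-dimensional device index. A usage sketch, taken from the updated ops.mlir test below (%s is a !mesh.sharding value, %c3 an index value):

  %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index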
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b2acbf20b3fb9..b3d69eb5e1a23 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,10 +50,10 @@ struct CreatorOpShardingInterface
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
- auto shardType = cast<ShapedType>(mesh::shardType(
- op->getResult(0).getType(),
- mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
- resultShardings[0]));
+ auto mesh =
+ mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+ auto shardType = cast<ShapedType>(
+ mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
Operation *newOp = nullptr;
// if the sharding introduces a new dynamic dimension, we take it from
// the dynamic sharding info. For now bail out if it's not
@@ -66,18 +66,19 @@ struct CreatorOpShardingInterface
assert(oldType.getRank() == shardType.getRank());
int currOldOprndNum = -1;
mesh::ShardShapeOp shapeForDevice;
- Value device;
+ ValueRange device;
Operation *newSharding = nullptr;
for (auto i = 0; i < oldType.getRank(); ++i) {
if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
if (!newSharding) {
newSharding =
builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
- device = builder.create<mesh::ProcessLinearIndexOp>(
- op->getLoc(), resultShardings[0].getMesh());
+ device =
+ builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
+ .getResults();
shapeForDevice = builder.create<mesh::ShardShapeOp>(
- op->getLoc(), oldType.getShape(), newSharding->getResult(0),
- device);
+ op->getLoc(), oldType.getShape(), spmdizedOperands,
+ newSharding->getResult(0), device);
}
newOperands.emplace_back(shapeForDevice.getResult()[i]);
} else if (oldType.isDynamicDim(i)) {
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
new file mode 100644
index 0000000000000..1d5cedf57ec6c
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --convert-mesh-to-mpi \
+// RUN: | sed 's/%retval, %rank = mpi.comm_rank : !mpi.retval, i32/%rank = arith.constant 24 : i32/g' \
+// RUN: | mlir-opt -canonicalize \
+// RUN: | FileCheck %s
+
+// Note: The sed replacement above defines the linear index to be 24, which is the multi-index [1, 0, 4]
+
+// CHECK: mesh.mesh @mesh0
+mesh.mesh @mesh0(shape = 3x4x5)
+
+// all shards are equal
+// CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+func.func @shard_shape_equal() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// last shard in last dim gets an extra element
+// CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+func.func @shard_shape_odd_1() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// all except first shard in second dim get an extra element
+// CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+func.func @shard_shape_odd_2() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// all except first shard in first dim get an extra element
+// CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+func.func @shard_shape_odd_3() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// extract from sharded_dims_offsets
+// CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+ sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+}
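As a cross-check of the last case: the sed-defined rank 24 corresponds to the multi-index [1, 0, 4] on the 3x4x5 mesh, and each dimension reads its own segment of sharded_dims_offsets. Dim 0 uses [0, 1, 4, 9] at index 1, giving 4 - 1 = 3; dim 1 uses [0, 2, 6, 12, 12] at index 0, giving 2 - 0 = 2; dim 2 uses [0, 3, 6, 9, 12, 15] at index 4, giving 15 - 12 = 3; hence the expected 3, 2, 3.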
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 43a75bf3d8040..3d133f2255772 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -157,10 +157,12 @@ func.func @mesh_shard_shape() {
%c3 = arith.constant 3 : index
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
%s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
- %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
- // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
- %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+ // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+ // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
+ // CHECK-SAME: ] : index, index
+ %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+ // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+ %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
return
}
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 5443eea83aa2d..01cf5972177f4 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -10,8 +10,9 @@ func.func @tensor_empty_static_sharded_dims_offsets() -> () {
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
%sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
// CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+ // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
// CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
return
@@ -24,8 +25,10 @@ func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
%sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
// CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+ // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
+ // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
// CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
return
>From 1faf707c726f670cfadae43b489df00e233ec72b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 11:58:17 +0100
Subject: [PATCH 7/8] cleanup & formatting
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 254 +++++++++-----------
1 file changed, 112 insertions(+), 142 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index b94181dd4f061..9d638100e637b 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -41,10 +41,10 @@ namespace mlir {
} // namespace mlir
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
namespace {
-// convert vec of OpFoldResults (ints) into vec of Values
+/// Convert a vector of OpFoldResults (ints) into a vector of Values.
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
llvm::ArrayRef<int64_t> statics,
ValueRange dynamics,
@@ -61,13 +61,12 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
? *(dyn++)
: b.create<arith::ConstantOp>(loc, type,
i64 == type ? b.getI64IntegerAttr(s)
- : b.getIndexAttr(s))
- .getResult());
+ : b.getIndexAttr(s)));
}
return values;
};
-// Create operations converting a linear index to a multi-dimensional index
+/// Create operations converting a linear index to a multi-dimensional index.
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
Value linearIndex,
ValueRange dimensions) {
@@ -76,23 +75,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
for (int i = n - 1; i >= 0; --i) {
multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
- if (i > 0) {
+ if (i > 0)
linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
- }
}
return multiIndex;
}
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
- auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+ Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+ Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
}
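For reference, these two helpers implement the usual row-major mapping: for dims [d0, d1, d2] the linear index of [i0, i1, i2] is (i0 * d1 + i1) * d2 + i2, and the multi-index is recovered by repeated mod/div starting from the innermost dimension. On the 3x4x5 mesh used in the tests, rank 24 decomposes as 24 mod 5 = 4, 24 div 5 = 4, 4 mod 4 = 0, 4 div 4 = 1, giving [1, 0, 4].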
@@ -100,35 +98,34 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
return linearIndex;
}
-// replace GetShardingOp with related ShardingOp
-struct ConvertGetShardingOp
- : public mlir::OpConversionPattern<mlir::mesh::GetShardingOp> {
+/// Replace GetShardingOp with the ShardingOp defining its source's sharding.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::GetShardingOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
- if (!shardOp) {
- return mlir::failure();
- }
+ if (!shardOp)
+ return failure();
auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
- if (!shardingOp) {
- return mlir::failure();
- }
+ if (!shardingOp)
+ return failure();
rewriter.replaceOp(op, shardingOp.getResult());
- return mlir::success();
+ return success();
}
};
-struct ConvertShardingOp
- : public mlir::OpConversionPattern<mlir::mesh::ShardingOp> {
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by the type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ShardingOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto splitAxes = op.getSplitAxes().getAxes();
int64_t maxNAxes = 0;
for (auto axes : splitAxes) {
@@ -138,17 +135,15 @@ struct ConvertShardingOp
// To hold the split axes, create empty 2d tensor with shape
// {splitAxes.size(), max-size-of-split-groups}.
// Set trailing elements for smaller split-groups to -1.
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
auto i16 = rewriter.getI16Type();
auto i64 = rewriter.getI64Type();
int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
auto attr = IntegerAttr::get(i16, 0xffff);
- auto fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
- resSplitAxes =
- rewriter
- .create<linalg::FillOp>(loc, fillValue.getResult(), resSplitAxes)
- .getResult(0);
+ Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+ resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+ .getResult(0);
// explicitly write values into tensor row by row
int64_t strides[] = {1, 1};
@@ -156,9 +151,8 @@ struct ConvertShardingOp
ValueRange empty = {};
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
int64_t size = axes.size();
- if (size > 0) {
+ if (size > 0)
++nSplits;
- }
int64_t offs[] = {(int64_t)i, 0};
int64_t sizes[] = {1, size};
auto tensorType = RankedTensorType::get({size}, i16);
@@ -197,7 +191,7 @@ struct ConvertShardingOp
auto meshOp = getMesh(op, symbolTableCollection);
auto maxSplitSize = 0;
for (auto axes : splitAxes) {
- auto splitSize =
+ int64_t splitSize =
collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
assert(splitSize != ShapedType::kDynamic);
maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
@@ -207,11 +201,8 @@ struct ConvertShardingOp
resOffsets = rewriter.create<tensor::EmptyOp>(
loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
- auto zero =
- rewriter
- .create<arith::ConstantOp>(
- loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic))
- .getResult();
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
resOffsets =
rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
auto offsets =
@@ -219,12 +210,12 @@ struct ConvertShardingOp
adaptor.getDynamicShardedDimsOffsets());
int64_t curr = 0;
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
- auto splitSize =
+ int64_t splitSize =
collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
- auto vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
int64_t offs[] = {(int64_t)i, 0};
int64_t sizes[] = {1, splitSize};
resOffsets = rewriter.create<tensor::InsertSliceOp>(
@@ -236,9 +227,9 @@ struct ConvertShardingOp
// return a tuple of tensors as defined by type converter
SmallVector<Type> resTypes;
if (failed(getTypeConverter()->convertType(op.getResult().getType(),
- resTypes))) {
- return mlir::failure();
- }
+ resTypes)))
+ return failure();
+
resSplitAxes =
rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
resHaloSizes =
@@ -249,35 +240,33 @@ struct ConvertShardingOp
op, TupleType::get(op.getContext(), resTypes),
ValueRange{resSplitAxes, resHaloSizes, resOffsets});
- return mlir::success();
+ return success();
}
};
struct ConvertProcessMultiIndexOp
- : public mlir::OpConversionPattern<mlir::mesh::ProcessMultiIndexOp> {
+ : public OpConversionPattern<ProcessMultiIndexOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// Currently converts its linear index to a multi-dimensional index.
SymbolTableCollection symbolTableCollection;
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
auto meshOp = getMesh(op, symbolTableCollection);
// For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape())) {
- return mlir::failure();
- }
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return failure();
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
});
- auto rank =
- rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+ Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// optionally extract subset of mesh axes
@@ -291,22 +280,22 @@ struct ConvertProcessMultiIndexOp
}
rewriter.replaceOp(op, mIdx);
- return mlir::success();
+ return success();
}
};
struct ConvertProcessLinearIndexOp
- : public mlir::OpConversionPattern<mlir::mesh::ProcessLinearIndexOp> {
+ : public OpConversionPattern<ProcessLinearIndexOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// Finds a global named "static_mpi_rank" it will use that splat value.
// Otherwise it defaults to mpi.comm_rank.
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
op, rankOpName)) {
@@ -314,7 +303,7 @@ struct ConvertProcessLinearIndexOp
auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
rewriter.replaceOp(op,
rewriter.create<arith::ConstantIndexOp>(loc, val));
- return mlir::success();
+ return success();
}
}
auto rank =
@@ -325,17 +314,17 @@ struct ConvertProcessLinearIndexOp
.getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
- return mlir::success();
+ return success();
}
};
struct ConvertNeighborsLinearIndicesOp
- : public mlir::OpConversionPattern<mlir::mesh::NeighborsLinearIndicesOp> {
+ : public OpConversionPattern<NeighborsLinearIndicesOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// Computes the neighbors indices along a split axis by simply
// adding/subtracting 1 to the current index in that dimension.
@@ -343,11 +332,10 @@ struct ConvertNeighborsLinearIndicesOp
auto axes = adaptor.getSplitAxes();
// For now only single axis sharding is supported
- if (axes.size() != 1) {
- return mlir::failure();
- }
+ if (axes.size() != 1)
+ return failure();
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
auto mIdx = adaptor.getDevice();
@@ -357,12 +345,12 @@ struct ConvertNeighborsLinearIndicesOp
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
});
- auto dimSz = dims[axes[0]];
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
- auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
- auto atBorder = rewriter.create<arith::CmpIOp>(
+ Value dimSz = dims[axes[0]];
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+ Value atBorder = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, orgIdx,
- rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
auto down = rewriter.create<scf::IfOp>(
loc, atBorder,
[&](OpBuilder &builder, Location loc) {
@@ -387,23 +375,21 @@ struct ConvertNeighborsLinearIndicesOp
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
- .getResult();
+ rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
builder.create<scf::YieldOp>(
loc, multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
- return mlir::success();
+ return success();
}
};
-struct ConvertShardShapeOp
- : public mlir::OpConversionPattern<mlir::mesh::ShardShapeOp> {
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ShardShapeOp op, OneToNOpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
if (!sharding) {
return op->emitError()
@@ -411,7 +397,7 @@ struct ConvertShardShapeOp
<< " but found " << adaptor.getSharding()[0].getDefiningOp();
}
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
Type index = rewriter.getIndexType();
// Fill a vector with the dynamic dims
@@ -436,9 +422,8 @@ struct ConvertShardShapeOp
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(sharding, symbolTableCollection);
// For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape())) {
- return mlir::failure();
- }
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return failure();
auto splitAxes = sharding.getSplitAxes().getAxes();
Value shardedDimsOffs;
@@ -447,12 +432,8 @@ struct ConvertShardShapeOp
rewriter, loc, sharding.getStaticShardedDimsOffsets(),
sharding.getDynamicShardedDimsOffsets(), index);
if (!tmp.empty())
- shardedDimsOffs =
- rewriter
- .create<tensor::FromElementsOp>(
- loc, RankedTensorType::get({(int64_t)tmp.size()}, index),
- tmp)
- .getResult();
+ shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+ loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
}
int64_t pos = 0; // position in shardedDimsOffs
@@ -473,17 +454,16 @@ struct ConvertShardShapeOp
if (shardedDimsOffs) {
// If sharded dims offsets are provided, use them to compute the
// sharded shape.
- if (axes.size() > 1)
+ if (axes.size() > 1) {
return op->emitError() << "Only single axis sharding is "
- "supported for each dimension.";
+ << "supported for each dimension.";
+ }
idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
Value off =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx)
- .getResult();
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
idx = rewriter.create<arith::AddIOp>(loc, idx, one);
Value nextOff =
- rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx)
- .getResult();
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
shardShape.emplace_back(sz);
} else {
@@ -510,17 +490,16 @@ struct ConvertShardShapeOp
}
assert(shardShape.size() == shape.size());
rewriter.replaceOp(op, shardShape);
- return mlir::success();
+ return success();
}
};
-struct ConvertUpdateHaloOp
- : public mlir::OpConversionPattern<mlir::mesh::UpdateHaloOp> {
+struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::UpdateHaloOp op, OpAdaptor adaptor,
- mlir::ConversionPatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// The input/output memref is assumed to be in C memory order.
// Halos are exchanged as 2 blocks per dimension (one for each side: down
@@ -539,17 +518,17 @@ struct ConvertUpdateHaloOp
if (haloSizes.empty()) {
// no halos -> nothing to do
rewriter.replaceOp(op, adaptor.getDestination());
- return mlir::success();
+ return success();
}
SymbolTableCollection symbolTableCollection;
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
// convert a OpFoldResult into a Value
auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
if (auto value = dyn_cast<Value>(v))
return value;
- return rewriter.create<::mlir::arith::ConstantOp>(
+ return rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(
cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
};
@@ -561,8 +540,8 @@ struct ConvertUpdateHaloOp
// If the destination is a memref, we need to cast it to a tensor
auto tensorType = MemRefType::get(
dstShape, cast<ShapedType>(array.getType()).getElementType());
- array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
- .getResult();
+ array =
+ rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array);
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -570,12 +549,11 @@ struct ConvertUpdateHaloOp
auto meshOp = getMesh(op, symbolTableCollection);
// subviews need Index values
for (auto &sz : haloSizes) {
- if (auto value = dyn_cast<Value>(sz)) {
+ if (auto value = dyn_cast<Value>(sz))
sz =
rewriter
.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
.getResult();
- }
}
// most of the offset/size/stride data is the same for all dims
@@ -586,11 +564,10 @@ struct ConvertUpdateHaloOp
// we need the actual shape to compute offsets and sizes
for (auto i = 0; i < rank; ++i) {
auto s = dstShape[i];
- if (ShapedType::isDynamic(s)) {
+ if (ShapedType::isDynamic(s))
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
- } else {
+ else
shape[i] = rewriter.getIndexAttr(s);
- }
if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
++currHaloDim;
@@ -598,11 +575,9 @@ struct ConvertUpdateHaloOp
offsets[i] = haloSizes[currHaloDim * 2];
// prepare shape and offsets of highest dim's halo exchange
- auto _haloSz =
- rewriter
- .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
- toValue(haloSizes[currHaloDim * 2 + 1]))
- .getResult();
+ Value _haloSz = rewriter.create<arith::AddIOp>(
+ loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]));
       // the halo shape of lower dims excludes the halos
dimSizes[i] =
rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
@@ -613,9 +588,9 @@ struct ConvertUpdateHaloOp
}
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
- auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+ auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
- auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+ auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
rewriter.getIndexType());
@@ -625,9 +600,8 @@ struct ConvertUpdateHaloOp
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
auto splitAxes = opSplitAxes[dim];
- if (splitAxes.empty()) {
+ if (splitAxes.empty())
continue;
- }
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
@@ -697,7 +671,7 @@ struct ConvertUpdateHaloOp
auto get_i32val = [&](OpFoldResult &v) {
return isa<Value>(v)
? cast<Value>(v)
- : rewriter.create<::mlir::arith::ConstantOp>(
+ : rewriter.create<arith::ConstantOp>(
loc,
rewriter.getI32IntegerAttr(
cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
@@ -730,7 +704,7 @@ struct ConvertUpdateHaloOp
loc, op.getResult().getType(), array,
/*restrict=*/true, /*writable=*/true));
}
- return mlir::success();
+ return success();
}
};
@@ -741,9 +715,9 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
auto *ctxt = &getContext();
- mlir::RewritePatternSet patterns(ctxt);
+ RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
- mlir::TypeConverter typeConverter;
+ TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
// convert mesh::ShardingType to a tuple of RankedTensorTypes
@@ -757,25 +731,22 @@ struct ConvertMeshToMPIPass
results.emplace_back(RankedTensorType::get(shp, i16));
results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
results.emplace_back(RankedTensorType::get(shp, i64));
- return mlir::success();
+ return success();
});
// To extract components a UnrealizedConversionCastOp is expected to define
// the input
typeConverter.addTargetMaterialization(
[&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc) {
- if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType())) {
+ if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
return SmallVector<Value>();
- }
auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
- if (!castOp) {
+ if (!castOp)
return SmallVector<Value>();
- }
SmallVector<Value> results;
for (auto oprnd : castOp.getInputs()) {
- if (!isa<RankedTensorType>(oprnd.getType())) {
+ if (!isa<RankedTensorType>(oprnd.getType()))
return SmallVector<Value>();
- }
results.emplace_back(oprnd);
}
return results;
@@ -802,8 +773,7 @@ struct ConvertMeshToMPIPass
populateCallOpTypeConversionPattern(patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
- (void)mlir::applyPartialConversion(getOperation(), target,
- std::move(patterns));
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
}
};
>From 95582a8ea943a62efac7d343eec6b466521afdad Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 12:56:42 +0100
Subject: [PATCH 8/8] using DLTI instead of global symbol to get optional
static MPI rank
---
mlir/include/mlir/Conversion/Passes.td | 8 +-
mlir/lib/Conversion/MeshToMPI/CMakeLists.txt | 1 +
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 121 ++++++--
.../MeshToMPI/convert-mesh-to-mpi.mlir | 279 +++++++++---------
.../MeshToMPI/convert-shardshape-to-mpi.mlir | 136 ++++-----
5 files changed, 305 insertions(+), 240 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
let description = [{
This pass converts communication operations from the Mesh dialect to the
MPI dialect.
- If it finds a global named "static_mpi_rank" it will use that splat value
- instead of calling MPI_Comm_rank. This allows optimizations like constant
- shape propagation and fusion because shard/partition sizes depend on the
- rank.
+    If it finds the DLTI attribute "MPI:comm_world_rank" on the module, it
+    will use that integer value instead of calling MPI_Comm_rank. This allows
+ optimizations like constant shape propagation and fusion because
+ shard/partition sizes depend on the rank.
}];
let dependentDialects = [
"memref::MemRefDialect",
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
Core
LINK_LIBS PUBLIC
+ MLIRDLTIDialect
MLIRFuncDialect
MLIRIR
MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 9d638100e637b..84db6d456711c 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -284,28 +285,29 @@ struct ConvertProcessMultiIndexOp
}
};
-struct ConvertProcessLinearIndexOp
+class ConvertProcessLinearIndexOp
: public OpConversionPattern<ProcessLinearIndexOp> {
+ int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
+
+public:
using OpConversionPattern::OpConversionPattern;
+ // Constructor accepting worldRank
+ ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+ MLIRContext *context, int64_t worldRank_ = -1)
+ : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+
LogicalResult
matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Finds a global named "static_mpi_rank" it will use that splat value.
- // Otherwise it defaults to mpi.comm_rank.
-
Location loc = op.getLoc();
- auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
- if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
- op, rankOpName)) {
- if (auto initTnsr = globalOp.getInitialValueAttr()) {
- auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
- rewriter.replaceOp(op,
- rewriter.create<arith::ConstantIndexOp>(loc, val));
- return success();
- }
+ if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+ return success();
}
+
+    // Otherwise create an mpi::CommRankOp.
auto rank =
rewriter
.create<mpi::CommRankOp>(
@@ -397,10 +399,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
<< " but found " << adaptor.getSharding()[0].getDefiningOp();
}
+ // Compute the sharded shape by applying the sharding to the input shape.
+ // Without shardedDimsOffsets in the sharding, the shard shape is computed
+ // by dividing the dimension size by the number of shards in that dimension
+ // (which is given by the size of the mesh axes provided in split-axes).
+ // Odd elements get distributed to trailing shards.
+ // If a shardedDimsOffsets is provided, the shard shape is computed by
+ // subtracting the offset of the current shard from the offset of the next
+ // shard.
+
Location loc = op.getLoc();
Type index = rewriter.getIndexType();
- // Fill a vector with the dynamic dims
+ // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and device
+    // we have a 1:1 conversion.
+    // For simpler access, fill a vector with the dynamic dims.
SmallVector<Value> dynDims, dynDevice;
for (auto dim : adaptor.getDimsDynamic()) {
// type conversion should be 1:1 for ints
@@ -413,12 +427,13 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
dynDevice.emplace_back(device[0]);
}
- // To keep the code simple, we convert dims/device to values when they are
- // attributes
+ // To keep the code simple, convert dims/device to values when they are
+ // attributes. Count on canonicalization to fold static values.
auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
auto multiIdx =
getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+ // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(sharding, symbolTableCollection);
// For now we only support static mesh shapes
@@ -426,6 +441,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
return failure();
auto splitAxes = sharding.getSplitAxes().getAxes();
+ // shardedDimsOffsets are optional and might be Values (not attributes).
+    // Also, the shardId might be dynamic, which means the position in the
+    // shardedDimsOffsets is not statically known. Create a tensor of the
+    // shardedDimsOffsets and later extract the offsets to compute the
+    // local shard size.
Value shardedDimsOffs;
{
auto tmp = getMixedAsValues(
@@ -436,18 +456,28 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
}
- int64_t pos = 0; // position in shardedDimsOffs
+    // With a static mesh shape the sizes of the split axes are known.
+    // Hence the start position for each split axis in shardedDimsOffsets
+    // can be computed statically.
+ int64_t pos = 0;
SmallVector<Value> shardShape;
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+ // Iterate over the dimensions of the tensor shape, get their split Axes,
+ // and compute the sharded shape.
for (auto [i, dim] : llvm::enumerate(shape)) {
- if (i < splitAxes.size()) { // trailing dimensions might not be annotated
+ // Trailing dimensions might not be annotated.
+ if (i < splitAxes.size()) {
auto axes = splitAxes[i];
+ // The current dimension might not be sharded.
if (!axes.empty()) {
+          // Create a value from the static position in shardedDimsOffsets.
Value posVal = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(pos));
+ // Get the index of the local shard in the mesh axis.
Value idx = multiIdx[axes[0]];
auto _numShards =
collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
@@ -459,6 +489,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
<< "supported for each dimension.";
}
idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+ // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
Value off =
rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
idx = rewriter.create<arith::AddIOp>(loc, idx, one);
@@ -469,9 +500,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
} else {
auto numShards = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(_numShards));
- // compute shard dim size by distributing odd elements to trailing
- // shards sz = dim / numShards + (idx >= (numShards - (dim %
- // numShards)) ? one : zero)
+ // Compute shard dim size by distributing odd elements to trailing
+ // shards:
+ // sz = dim / numShards
+ // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
@@ -481,9 +513,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
shardShape.emplace_back(sz);
}
- pos += _numShards + 1; // add one for the total size
- } // else no sharding in this dimension
- }
+ pos += _numShards + 1; // add one for the total size.
+ } // else no sharding if split axis is empty.
+ } // else no sharding if no split axis
// If no size was added -> no sharding in this dimension.
if (shardShape.size() <= i)
shardShape.emplace_back(dim);
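To make the even-distribution rule above concrete (numbers taken from the shard_shape_odd_1 test): dim = 16 over numShards = 5 gives 16 / 5 = 3 with remainder 1, so only shards with idx >= 5 - 1 = 4 receive an extra element; the shard sizes are 3, 3, 3, 3, 4, and device index 4 therefore returns 4.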
@@ -714,12 +746,31 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
+    int64_t worldRank = -1;
+    // Try to get the DLTI attribute for MPI:comm_world_rank.
+    // If found, set worldRank to the value of the attribute.
+ {
+ auto dltiAttr =
+ dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+ if (succeeded(dltiAttr)) {
+ if (!isa<IntegerAttr>(dltiAttr.value())) {
+ getOperation()->emitError()
+ << "Expected an integer attribute for MPI:comm_world_rank";
+ return signalPassFailure();
+ }
+ worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+ }
+ }
+
auto *ctxt = &getContext();
RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
- TypeConverter typeConverter;
+ // Define a type converter to convert mesh::ShardingType,
+ // mostly for use in return operations.
+ TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
+
// convert mesh::ShardingType to a tuple of RankedTensorTypes
typeConverter.addConversion(
[](ShardingType type,
@@ -733,16 +784,20 @@ struct ConvertMeshToMPIPass
results.emplace_back(RankedTensorType::get(shp, i64));
return success();
});
- // To extract components a UnrealizedConversionCastOp is expected to define
- // the input
+
+    // To 'extract' components, an UnrealizedConversionCastOp is expected
+    // to define the input.
typeConverter.addTargetMaterialization(
[&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
Location loc) {
+ // Expecting a single input.
if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
return SmallVector<Value>();
auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+ // Expecting an UnrealizedConversionCastOp.
if (!castOp)
return SmallVector<Value>();
+ // Fill a vector with elements of the tuple/castOp.
SmallVector<Value> results;
for (auto oprnd : castOp.getInputs()) {
if (!isa<RankedTensorType>(oprnd.getType()))
@@ -752,12 +807,16 @@ struct ConvertMeshToMPIPass
return results;
});
+    // No mesh dialect ops should be left after conversion...
target.addIllegalDialect<mesh::MeshDialect>();
+ // ...except the global MeshOp
target.addLegalOp<mesh::MeshOp>();
+    // Allow all dialects that the conversion patterns may produce.
target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
arith::ArithDialect, tensor::TensorDialect,
bufferization::BufferizationDialect,
linalg::LinalgDialect, memref::MemRefDialect>();
+    // Make sure function signatures, calls, etc. are legal.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType());
});
@@ -765,9 +824,11 @@ struct ConvertMeshToMPIPass
[&](Operation *op) { return typeConverter.isLegal(op); });
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
- ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp,
- ConvertShardingOp, ConvertGetShardingOp, ConvertShardShapeOp>(
- typeConverter, ctxt);
+ ConvertProcessMultiIndexOp, ConvertGetShardingOp,
+ ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
+ // ConvertProcessLinearIndexOp accepts an optional worldRank
+ patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 6d865ddaf7ce5..4e60c6f0d4e44 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -60,23 +60,24 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
// -----
// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-func.func @process_multi_index() -> (index, index, index) {
- // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
- // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
- // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
- return %0#0, %0#1, %0#2 : index, index, index
-}
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+ mesh.mesh @mesh0(shape = 3x4x5)
+ func.func @process_multi_index() -> (index, index, index) {
+ // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+ }
-// CHECK-LABEL: func @process_linear_index
-func.func @process_linear_index() -> index {
- // CHECK: %[[c24:.*]] = arith.constant 24 : index
- %0 = mesh.process_linear_index on @mesh0 : index
- // CHECK: return %[[c24]] : index
- return %0 : index
+ // CHECK-LABEL: func @process_linear_index
+ func.func @process_linear_index() -> index {
+ // CHECK: %[[c24:.*]] = arith.constant 24 : index
+ %0 = mesh.process_linear_index on @mesh0 : index
+ // CHECK: return %[[c24]] : index
+ return %0 : index
+ }
}
// -----
@@ -103,132 +104,134 @@ func.func @update_halo_1d_first(
}
// -----
-mesh.mesh @mesh0(shape = 4)
-memref.global "public" constant @static_mpi_rank : memref<index> = dense<1>
-// CHECK-LABEL: func @update_halo_1d_with_zero
-func.func @update_halo_1d_with_zero (
- // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
- %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
- // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
- // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
- // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
- // CHECK-SAME: to memref<2x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
- // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
- return %res : memref<120x120x120xi8>
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
+ mesh.mesh @mesh0(shape = 4)
+ // CHECK-LABEL: func @update_halo_1d_with_zero
+ func.func @update_halo_1d_with_zero (
+ // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+ // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
+ // CHECK-SAME: to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+ // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+ }
}
// -----
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-// CHECK-LABEL: func @update_halo_3d
-func.func @update_halo_3d(
- // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
- %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
- // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
- // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
- // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
- // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
- // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
- // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
- // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
- // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
- // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
- // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
- // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
- // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
- // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
- // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
- // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
- // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
- // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
- // CHECK: return [[varg0]] : memref<120x120x120xi8>
- return %res : memref<120x120x120xi8>
-}
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+ mesh.mesh @mesh0(shape = 3x4x5)
+ // CHECK-LABEL: func @update_halo_3d
+ func.func @update_halo_3d(
+ // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+ // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+ // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+ // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+ // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+ // CHECK: return [[varg0]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+ }
-// CHECK-LABEL: func @update_halo_3d_tensor
-func.func @update_halo_3d_tensor(
- // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
- %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
- // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
- // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
- // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
- // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
- // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
- // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
- // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
- // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
- // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
- // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
- // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
- // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
- // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
- // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
- // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
- // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
- // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
- // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
- // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
- // CHECK: return [[v1]] : tensor<120x120x120xi8>
- return %res : tensor<120x120x120xi8>
+ // CHECK-LABEL: func @update_halo_3d_tensor
+ func.func @update_halo_3d_tensor(
+ // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
+ %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
+ // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+ // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+ // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+ // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+ // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+ // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+ // CHECK: return [[v1]] : tensor<120x120x120xi8>
+ return %res : tensor<120x120x120xi8>
+ }
}
// -----
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
index 1d5cedf57ec6c..1cc848333ced2 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -1,75 +1,75 @@
-// RUN: mlir-opt %s --convert-mesh-to-mpi \
-// RUN: | sed 's/%retval, %rank = mpi.comm_rank : !mpi.retval, i32/%rank = arith.constant 24 : i32/g' \
-// RUN: | mlir-opt -canonicalize \
-// RUN: | FileCheck %s
+// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
-// Notice: The sed replacement above defines the linear index to be 24, which is multiindex [1, 0, 4]
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
-// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
+ // CHECK: mesh.mesh @mesh0
+ mesh.mesh @mesh0(shape = 3x4x5)
+
+ // Notice: comm_world_rank/linear index 24 is multiindex [1, 0, 4] in @mesh0
-// all shards are equal
-// CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
-func.func @shard_shape_equal() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %c9 = arith.constant 9 : index
- %c12 = arith.constant 12 : index
- // CHECK: [[vc3:%.*]] = arith.constant 3 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
- // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
- return %1#0, %1#1, %1#2 : index, index, index
-}
+ // all shards are equal
+ // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+ func.func @shard_shape_equal() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
-// last shard in last dim gets an extra element
-// CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
-func.func @shard_shape_odd_1() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %c9 = arith.constant 9 : index
- %c12 = arith.constant 12 : index
- // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
- // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
- // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
- return %1#0, %1#1, %1#2 : index, index, index
-}
+ // last shard in last dim gets an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+ func.func @shard_shape_odd_1() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
-// all except first shard in second dim get an extra element
-// CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
-func.func @shard_shape_odd_2() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %c9 = arith.constant 9 : index
- // CHECK: [[vc3:%.*]] = arith.constant 3 : index
- %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
- // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
- return %1#0, %1#1, %1#2 : index, index, index
-}
+ // all except first shard in second dim get an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+ func.func @shard_shape_odd_2() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
-// all except first shard in first dim get an extra element
-// CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
-func.func @shard_shape_odd_3() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
- // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
- // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
- %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
- // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
- return %1#0, %1#1, %1#2 : index, index, index
-}
+ // all except first shard in first dim get an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+ func.func @shard_shape_odd_3() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
-// extract from sharded_dims_offsets
-// CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
-func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
- %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
- sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
- %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
- %c9 = arith.constant 9 : index
- %c12 = arith.constant 12 : index
- // CHECK: [[vc3:%.*]] = arith.constant 3 : index
- // CHECK: [[vc2:%.*]] = arith.constant 2 : index
- %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
- // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
- return %1#0, %1#1, %1#2 : index, index, index
-}
+ // extract from sharded_dims_offsets
+ // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+ sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+}
\ No newline at end of file
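
Side note (illustration only, not part of the patch): the "MPI:comm_world_rank" = 24 value used in the tests above corresponds to the mesh multi-index [1, 0, 4] by plain row-major delinearization of the linear rank over the 3x4x5 mesh shape. A minimal standalone C++ sketch of that mapping, with a made-up helper name, assuming row-major ordering as stated in the test comment:

#include <array>
#include <cassert>
#include <cstdint>

// Row-major delinearization: the last mesh dimension varies fastest.
static std::array<int64_t, 3> delinearize(int64_t rank,
                                           std::array<int64_t, 3> shape) {
  std::array<int64_t, 3> idx{};
  for (int d = 2; d >= 0; --d) {
    idx[d] = rank % shape[d];
    rank /= shape[d];
  }
  return idx;
}

int main() {
  // Rank 24 on a 3x4x5 mesh -> [1, 0, 4], matching the test comment.
  std::array<int64_t, 3> idx = delinearize(24, {3, 4, 5});
  assert(idx[0] == 1 && idx[1] == 0 && idx[2] == 4);
  return 0;
}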