[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Feb 28 01:56:30 PST 2025
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/129048
From 2e59db0e7aec204f74a1e385fd618f8a2b466a53 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:43:25 +0100
Subject: [PATCH 1/7] Add a type converter for ShardingType, allowing a
 !mesh.sharding to be returned
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 271 +++++++++++++++-
.../MeshToMPI/convert-mesh-to-mpi.mlir | 294 ++++++++++++------
2 files changed, 457 insertions(+), 108 deletions(-)
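
In short, the new type converter lowers a !mesh.sharding value to three ranked
tensors holding the split axes (i16), the halo sizes (i64) and the sharded dims
offsets (i64). For example (mirroring the return_sharding test added below), a
function such as

  func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding)

is rewritten so that its converted signature becomes

  (tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>)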
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 48b3764d520c2..5d230a24a6316 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -14,8 +14,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,6 +30,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "mesh-to-mpi"
@@ -36,10 +42,32 @@ namespace mlir {
} // namespace mlir
using namespace mlir;
-using namespace mlir::mesh;
+using namespace mesh;
namespace {
-// Create operations converting a linear index to a multi-dimensional index
+/// Convert a mixed list of static (int64_t) and dynamic values into a vector
+/// of Values, materializing arith constants for the static entries.
+static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics,
+ Type type = Type()) {
+ SmallVector<Value> values;
+ auto dyn = dynamics.begin();
+ Type i64 = b.getI64Type();
+ if (!type)
+ type = i64;
+ assert(i64 == type || b.getIndexType() == type);
+ for (auto s : statics) {
+ values.emplace_back(
+ ShapedType::isDynamic(s)
+ ? *(dyn++)
+ : b.create<arith::ConstantOp>(loc, type,
+ i64 == type ? b.getI64IntegerAttr(s)
+ : b.getIndexAttr(s)));
+ }
+ return values;
+};
+
+/// Create operations converting a linear index to a multi-dimensional index.
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
Value linearIndex,
ValueRange dimensions) {
@@ -72,6 +100,152 @@ Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
return linearIndex;
}
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp)
+ return failure();
+ auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+ if (!shardingOp)
+ return failure();
+
+ rewriter.replaceOp(op, shardingOp.getResult());
+ return success();
+ }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splitAxes = op.getSplitAxes().getAxes();
+ int64_t maxNAxes = 0;
+ for (auto axes : splitAxes) {
+ maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+ }
+
+ // To hold the split axes, create empty 2d tensor with shape
+ // {splitAxes.size(), max-size-of-split-groups}.
+ // Set trailing elements for smaller split-groups to -1.
+ Location loc = op.getLoc();
+ auto i16 = rewriter.getI16Type();
+ auto i64 = rewriter.getI64Type();
+ int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+ Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+ auto attr = IntegerAttr::get(i16, 0xffff);
+ Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+ resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+ .getResult(0);
+
+ // explicitly write values into tensor row by row
+ int64_t strides[] = {1, 1};
+ int64_t nSplits = 0;
+ ValueRange empty = {};
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t size = axes.size();
+ if (size > 0)
+ ++nSplits;
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, size};
+ auto tensorType = RankedTensorType::get({size}, i16);
+ auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+ auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+ resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+ }
+
+    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
+ // Store the halo sizes in the tensor.
+ auto haloSizes =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+ adaptor.getDynamicHaloSizes());
+ auto type = RankedTensorType::get({nSplits, 2}, i64);
+ Value resHaloSizes =
+ haloSizes.empty()
+ ? rewriter
+ .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+ i64)
+ .getResult()
+ : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+ .getResult();
+
+    // To hold sharded dims offsets, create a tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to kDynamic. Computing the max size
+    // of the split groups requires collectiveProcessGroupSize (which needs
+    // the MeshOp).
+ Value resOffsets;
+ if (adaptor.getStaticShardedDimsOffsets().empty()) {
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{0, 0}, i64);
+ } else {
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto maxSplitSize = 0;
+ for (auto axes : splitAxes) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic);
+ maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+ }
+ assert(maxSplitSize);
+ ++maxSplitSize; // add one for the total size
+
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value fillValue = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets = rewriter.create<linalg::FillOp>(loc, fillValue, resOffsets)
+                       .getResult(0);
+ auto offsets =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+ adaptor.getDynamicShardedDimsOffsets());
+ int64_t curr = 0;
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+ ++splitSize; // add one for the total size
+ ArrayRef<Value> values(&offsets[curr], splitSize);
+ Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, splitSize};
+ resOffsets = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
+ curr += splitSize;
+ }
+ }
+
+ // return a tuple of tensors as defined by type converter
+ SmallVector<Type> resTypes;
+ if (failed(getTypeConverter()->convertType(op.getResult().getType(),
+ resTypes)))
+ return failure();
+
+ resSplitAxes =
+ rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
+ resHaloSizes =
+ rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
+ resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);
+
+ rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+ op, TupleType::get(op.getContext(), resTypes),
+ ValueRange{resSplitAxes, resHaloSizes, resOffsets});
+
+ return success();
+ }
+};
+
struct ConvertProcessMultiIndexOp
: public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
using OpRewritePattern::OpRewritePattern;
@@ -419,14 +593,95 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
- auto *ctx = &getContext();
- mlir::RewritePatternSet patterns(ctx);
+ uint64_t worldRank = -1;
+ // Try to get DLTI attribute for MPI:comm_world_rank
+ // If found, set worldRank to the value of the attribute.
+ {
+ auto dltiAttr =
+ dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+ if (succeeded(dltiAttr)) {
+ if (!isa<IntegerAttr>(dltiAttr.value())) {
+ getOperation()->emitError()
+ << "Expected an integer attribute for MPI:comm_world_rank";
+ return signalPassFailure();
+ }
+ worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+ }
+ }
- patterns.insert<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
- ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
- ctx);
+ auto *ctxt = &getContext();
+ RewritePatternSet patterns(ctxt);
+ ConversionTarget target(getContext());
+
+ // Define a type converter to convert mesh::ShardingType,
+ // mostly for use in return operations.
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) { return type; });
+
+ // convert mesh::ShardingType to a tuple of RankedTensorTypes
+ typeConverter.addConversion(
+ [](ShardingType type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp{ShapedType::kDynamic,
+ ShapedType::kDynamic};
+ results.emplace_back(RankedTensorType::get(shp, i16));
+ results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
+ results.emplace_back(RankedTensorType::get(shp, i64));
+ return success();
+ });
+
+    // To 'extract' components, an UnrealizedConversionCastOp is expected
+    // to define the input.
+ typeConverter.addTargetMaterialization(
+ [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc) {
+ // Expecting a single input.
+ if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
+ return SmallVector<Value>();
+ auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
+ // Expecting an UnrealizedConversionCastOp.
+ if (!castOp)
+ return SmallVector<Value>();
+ // Fill a vector with elements of the tuple/castOp.
+ SmallVector<Value> results;
+ for (auto oprnd : castOp.getInputs()) {
+ if (!isa<RankedTensorType>(oprnd.getType()))
+ return SmallVector<Value>();
+ results.emplace_back(oprnd);
+ }
+ return results;
+ });
- (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    // No mesh dialect ops should be left after conversion...
+ target.addIllegalDialect<mesh::MeshDialect>();
+ // ...except the global MeshOp
+ target.addLegalOp<mesh::MeshOp>();
+    // Allow all dialects that our patterns will convert to.
+ target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
+ arith::ArithDialect, tensor::TensorDialect,
+ bufferization::BufferizationDialect,
+ linalg::LinalgDialect, memref::MemRefDialect>();
+ // Make sure the function signature, calls etc. are legal
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return typeConverter.isSignatureLegal(op.getFunctionType());
+ });
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+ [&](Operation *op) { return typeConverter.isLegal(op); });
+
+ patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
+ ConvertProcessMultiIndexOp, ConvertGetShardingOp,
+ ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
+ // ConvertProcessLinearIndexOp accepts an optional worldRank
+ patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ patterns, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
}
};
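
Note that the rank is now obtained from a DLTI attribute on the module rather
than from the @static_mpi_rank global the old tests used; the updated tests
below wrap each case in a module along the lines of

  module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
    mesh.mesh @mesh0(shape = 3x4x5)
    ...
  }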
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index c1aef97438bd5..90bd80472c2b9 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -103,106 +103,200 @@ func.func @update_halo_1d_first(
}
// -----
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-// CHECK-LABEL: func @update_halo_3d
-func.func @update_halo_3d(
- // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
- %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
- // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
- // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
- // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
- // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
- // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
- // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
- // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
- // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
- // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
- // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[varg0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
- // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
- // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
- // CHECK: return [[varg0]] : memref<120x120x120xi8>
- return %res : memref<120x120x120xi8>
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 1> } {
+ mesh.mesh @mesh0(shape = 4)
+ // CHECK-LABEL: func @update_halo_1d_with_zero
+ func.func @update_halo_1d_with_zero (
+ // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+ // CHECK: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc0_i32:%.*]] = arith.constant 0 : i32
+ // CHECK-NEXT: [[vc2_i32:%.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<2x120x120xi8
+ // CHECK-SAME: to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc2_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc0_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][0, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<2x120x120xi8> to memref<2x120x120xi8
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<2x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>
+ // CHECK: return [[res:%.*]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+ }
+}
+
+// -----
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+ mesh.mesh @mesh0(shape = 3x4x5)
+ // CHECK-LABEL: func @update_halo_3d
+ func.func @update_halo_3d(
+ // CHECK-SAME: [[varg0:%.*]]: memref<120x120x120xi8>
+ %arg0 : memref<120x120x120xi8>) -> memref<120x120x120xi8> {
+ // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[varg0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+ // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+ // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[varg0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[varg0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[varg0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+ // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[varg0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[varg0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[varg0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : memref<120x120x120xi8>
+ // CHECK: return [[varg0]] : memref<120x120x120xi8>
+ return %res : memref<120x120x120xi8>
+ }
+
+ // CHECK-LABEL: func @update_halo_3d_tensor
+ func.func @update_halo_3d_tensor(
+ // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
+ %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
+ // CHECK-NEXT: [[vc23_i32:%.*]] = arith.constant 23 : i32
+ // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
+ // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
+ // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
+ // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
+ // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
+ // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
+ // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
+ // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
+ // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
+ // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
+ // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
+ // CHECK-NEXT: [[vsubview_2:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
+ // CHECK-NEXT: memref.copy [[vsubview_2]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
+ // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_3]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
+ // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
+ // CHECK-NEXT: [[valloc_4:%.*]] = memref.alloc() : memref<117x3x120xi8>
+ // CHECK-NEXT: [[vsubview_5:%.*]] = memref.subview [[v0]][1, 113, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>>
+ // CHECK-NEXT: memref.copy [[vsubview_5]], [[valloc_4]] : memref<117x3x120xi8, strided<[14400, 120, 1], offset: 27960>> to memref<117x3x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_4]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_4]] : memref<117x3x120xi8>
+ // CHECK-NEXT: [[valloc_6:%.*]] = memref.alloc() : memref<117x4x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_6]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 116, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.copy [[valloc_6]], [[vsubview_7]] : memref<117x4x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 28320>>
+ // CHECK-NEXT: memref.dealloc [[valloc_6]] : memref<117x4x120xi8>
+ // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<1x120x120xi8>
+ // CHECK-NEXT: mpi.recv([[valloc_8]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
+ // CHECK-NEXT: [[vsubview_9:%.*]] = memref.subview [[v0]][0, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.copy [[valloc_8]], [[vsubview_9]] : memref<1x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1]>>
+ // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<1x120x120xi8>
+ // CHECK-NEXT: [[valloc_10:%.*]] = memref.alloc() : memref<2x120x120xi8>
+ // CHECK-NEXT: [[vsubview_11:%.*]] = memref.subview [[v0]][1, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>>
+ // CHECK-NEXT: memref.copy [[vsubview_11]], [[valloc_10]] : memref<2x120x120xi8, strided<[14400, 120, 1], offset: 14400>> to memref<2x120x120xi8>
+ // CHECK-NEXT: mpi.send([[valloc_10]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
+ // CHECK-NEXT: memref.dealloc [[valloc_10]] : memref<2x120x120xi8>
+ // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8> to tensor<120x120x120xi8>
+ %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
+ // CHECK: return [[v1]] : tensor<120x120x120xi8>
+ return %res : tensor<120x120x120xi8>
+ }
+}
+
+// -----
+mesh.mesh @mesh0(shape = 2x2x4)
+// CHECK-LABEL: func.func @return_sharding(
+// CHECK-SAME: [[varg0:%.*]]: tensor<2x4xf32>) -> (tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding(%arg0: tensor<2x4xf32>) -> (tensor<2x4xf32>, !mesh.sharding) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] : !mesh.sharding
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+ // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_0]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice_1:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_1]] : tensor<2x2xi16> to tensor<?x?xi16>
+ // CHECK-NEXT: [[vcast_2:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[v3]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_2]], [[vcast_3]] : tensor<2x4xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+ return %arg0, %sharding : tensor<2x4xf32>, !mesh.sharding
+}
+
+// CHECK-LABEL: func.func @return_sharding_halos(
+// CHECK-SAME: [[varg0:%.*]]: tensor<6x8xf32>) -> (tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_halos(%arg0: tensor<6x8xf32>) -> (tensor<6x8xf32>, !mesh.sharding) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] halo_sizes = [0, 4, 3, 1] : !mesh.sharding
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<{{\[\[}}0, 4], [3, 1]]> : tensor<2x2xi64>
+ // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+ // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_1]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice_2:%.*]] = tensor.insert_slice [[vcst_0]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_2]] : tensor<2x2xi16> to tensor<?x?xi16>
+ // CHECK-NEXT: [[vcast_3:%.*]] = tensor.cast [[vcst]] : tensor<2x2xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: [[vcast_4:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_3]], [[vcast_4]] : tensor<6x8xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+ return %arg0, %sharding : tensor<6x8xf32>, !mesh.sharding
}
-// CHECK-LABEL: func @update_halo_3d_tensor
-func.func @update_halo_3d_tensor(
- // CHECK-SAME: [[varg0:%.*]]: tensor<120x120x120xi8>
- %arg0 : tensor<120x120x120xi8>) -> tensor<120x120x120xi8> {
- // CHECK: [[vc23_i32:%.*]] = arith.constant 23 : i32
- // CHECK-NEXT: [[vc29_i32:%.*]] = arith.constant 29 : i32
- // CHECK-NEXT: [[vc44_i32:%.*]] = arith.constant 44 : i32
- // CHECK-NEXT: [[vc4_i32:%.*]] = arith.constant 4 : i32
- // CHECK-NEXT: [[vc91_i32:%.*]] = arith.constant 91 : i32
- // CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<120x120x120xi8> to memref<120x120x120xi8>
- // CHECK-NEXT: [[valloc:%.*]] = memref.alloc() : memref<117x113x5xi8>
- // CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[v0]][1, 3, 109] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>>
- // CHECK-NEXT: memref.copy [[vsubview]], [[valloc]] : memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14869>> to memref<117x113x5xi8>
- // CHECK-NEXT: mpi.send([[valloc]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x5xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_0:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 113, 5] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[valloc]], [[vsubview_0]] : memref<117x113x5xi8> to memref<117x113x5xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.dealloc [[valloc]] : memref<117x113x5xi8>
- // CHECK-NEXT: [[valloc_1:%.*]] = memref.alloc() : memref<117x113x6xi8>
- // CHECK-NEXT: [[vsubview_3:%.*]] = memref.subview [[v0]][1, 3, 5] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>>
- // CHECK-NEXT: memref.copy [[vsubview_3]], [[valloc_1]] : memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14765>> to memref<117x113x6xi8>
- // CHECK-NEXT: mpi.send([[valloc_1]], [[vc91_i32]], [[vc44_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: mpi.recv([[valloc_1]], [[vc91_i32]], [[vc4_i32]]) : memref<117x113x6xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_4:%.*]] = memref.subview [[v0]][1, 3, 114] [117, 113, 6] [1, 1, 1] : memref<120x120x120xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.copy [[valloc_1]], [[vsubview_4]] : memref<117x113x6xi8> to memref<117x113x6xi8, strided<[14400, 120, 1], offset: 14874>>
- // CHECK-NEXT: memref.dealloc [[valloc_1]] : memref<117x113x6xi8>
- // CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc() : memref<117x3x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_5]], [[vc91_i32]], [[vc29_i32]]) : memref<117x3x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_7:%.*]] = memref.subview [[v0]][1, 0, 0] [117, 3, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.copy [[valloc_5]], [[vsubview_7]] : memref<117x3x120xi8> to memref<117x3x120xi8, strided<[14400, 120, 1], offset: 14400>>
- // CHECK-NEXT: memref.dealloc [[valloc_5]] : memref<117x3x120xi8>
- // CHECK-NEXT: [[valloc_8:%.*]] = memref.alloc() : memref<117x4x120xi8>
- // CHECK-NEXT: [[vsubview_10:%.*]] = memref.subview [[v0]][1, 3, 0] [117, 4, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>>
- // CHECK-NEXT: memref.copy [[vsubview_10]], [[valloc_8]] : memref<117x4x120xi8, strided<[14400, 120, 1], offset: 14760>> to memref<117x4x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_8]], [[vc91_i32]], [[vc29_i32]]) : memref<117x4x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_8]] : memref<117x4x120xi8>
- // CHECK-NEXT: [[valloc_11:%.*]] = memref.alloc() : memref<1x120x120xi8>
- // CHECK-NEXT: [[vsubview_12:%.*]] = memref.subview [[v0]][117, 0, 0] [1, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>>
- // CHECK-NEXT: memref.copy [[vsubview_12]], [[valloc_11]] : memref<1x120x120xi8, strided<[14400, 120, 1], offset: 1684800>> to memref<1x120x120xi8>
- // CHECK-NEXT: mpi.send([[valloc_11]], [[vc91_i32]], [[vc23_i32]]) : memref<1x120x120xi8>, i32, i32
- // CHECK-NEXT: memref.dealloc [[valloc_11]] : memref<1x120x120xi8>
- // CHECK-NEXT: [[valloc_13:%.*]] = memref.alloc() : memref<2x120x120xi8>
- // CHECK-NEXT: mpi.recv([[valloc_13]], [[vc91_i32]], [[vc23_i32]]) : memref<2x120x120xi8>, i32, i32
- // CHECK-NEXT: [[vsubview_14:%.*]] = memref.subview [[v0]][118, 0, 0] [2, 120, 120] [1, 1, 1] : memref<120x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.copy [[valloc_13]], [[vsubview_14]] : memref<2x120x120xi8> to memref<2x120x120xi8, strided<[14400, 120, 1], offset: 1699200>>
- // CHECK-NEXT: memref.dealloc [[valloc_13]] : memref<2x120x120xi8>
- // CHECK-NEXT: [[v1:%.*]] = bufferization.to_tensor [[v0]] restrict writable : memref<120x120x120xi8>
- %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[2], [1], [0]] halo_sizes = [1, 2, 3, 4, 5, 6] : tensor<120x120x120xi8>
- // CHECK: return [[v1]] : tensor<120x120x120xi8>
- return %res : tensor<120x120x120xi8>
+// CHECK-LABEL: func.func @return_sharding_offs(
+// CHECK-SAME: [[varg0:%.*]]: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>) {
+func.func @return_sharding_offs(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>, !mesh.sharding) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0, 1], [2]] sharded_dims_offsets = [0, 3, 5, 7, 8, 0, 0, 5, 10, 16] : !mesh.sharding
+ // CHECK-NEXT: [[vcst:%.*]] = arith.constant dense<[0, 0, 5, 10, 16]> : tensor<5xi64>
+ // CHECK-NEXT: [[vcst_0:%.*]] = arith.constant dense<[0, 3, 5, 7, 8]> : tensor<5xi64>
+ // CHECK-NEXT: [[vcm9223372036854775808_i64:%.*]] = arith.constant -9223372036854775808 : i64
+ // CHECK-NEXT: [[vcst_1:%.*]] = arith.constant dense<2> : tensor<1xi16>
+ // CHECK-NEXT: [[vcst_2:%.*]] = arith.constant dense<[0, 1]> : tensor<2xi16>
+ // CHECK-NEXT: [[vcm1_i16:%.*]] = arith.constant -1 : i16
+ // CHECK-NEXT: [[v0:%.*]] = tensor.empty() : tensor<2x2xi16>
+ // CHECK-NEXT: [[v1:%.*]] = linalg.fill ins([[vcm1_i16]] : i16) outs([[v0]] : tensor<2x2xi16>) -> tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice:%.*]] = tensor.insert_slice [[vcst_2]] into [[v1]][0, 0] [1, 2] [1, 1] : tensor<2xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[vinserted_slice_3:%.*]] = tensor.insert_slice [[vcst_1]] into [[vinserted_slice]][1, 0] [1, 1] [1, 1] : tensor<1xi16> into tensor<2x2xi16>
+ // CHECK-NEXT: [[v2:%.*]] = tensor.empty() : tensor<0x0xi64>
+ // CHECK-NEXT: [[v3:%.*]] = tensor.empty() : tensor<2x5xi64>
+ // CHECK-NEXT: [[v4:%.*]] = linalg.fill ins([[vcm9223372036854775808_i64]] : i64) outs([[v3]] : tensor<2x5xi64>) -> tensor<2x5xi64>
+ // CHECK-NEXT: [[vinserted_slice_4:%.*]] = tensor.insert_slice [[vcst_0]] into [[v4]][0, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+ // CHECK-NEXT: [[vinserted_slice_5:%.*]] = tensor.insert_slice [[vcst]] into [[vinserted_slice_4]][1, 0] [1, 5] [1, 1] : tensor<5xi64> into tensor<2x5xi64>
+ // CHECK-NEXT: [[vcast:%.*]] = tensor.cast [[vinserted_slice_3]] : tensor<2x2xi16> to tensor<?x?xi16>
+ // CHECK-NEXT: [[vcast_6:%.*]] = tensor.cast [[v2]] : tensor<0x0xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: [[vcast_7:%.*]] = tensor.cast [[vinserted_slice_5]] : tensor<2x5xi64> to tensor<?x?xi64>
+ // CHECK-NEXT: return [[varg0]], [[vcast]], [[vcast_6]], [[vcast_7]] : tensor<?x?xf32>, tensor<?x?xi16>, tensor<?x?xi64>, tensor<?x?xi64>
+ return %arg0, %sharding : tensor<?x?xf32>, !mesh.sharding
}
From 30e6cd18a84cbc6dccfef0857070403f9015fa70 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:48:07 +0100
Subject: [PATCH 2/7] Fix halo send/recv direction; no communication if no
halo
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 54 +++++++++++++++------
1 file changed, 38 insertions(+), 16 deletions(-)
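
In short, halo exchanges now send to and receive from the correct neighbor on
each side, and no send/recv pair is generated when the corresponding halo size
is zero. The zero-size case is exercised by the update_halo_1d_with_zero test,
e.g.

  // halo_sizes = [2, 0]: the upper halo is empty, so no communication is
  // generated for it; only the lower halo is exchanged.
  %res = mesh.update_halo %arg0 on @mesh0 split_axes = [[0]] halo_sizes = [2, 0] : memref<120x120x120xi8>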
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 5d230a24a6316..ee05276eb8366 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -410,42 +410,47 @@ struct ConvertUpdateHaloOp
// local data. Because subviews and halos can have mixed dynamic and static
// shapes, OpFoldResults are used whenever possible.
+ auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
+ adaptor.getHaloSizes(), rewriter);
+ if (haloSizes.empty()) {
+ // no halos -> nothing to do
+ rewriter.replaceOp(op, adaptor.getDestination());
+ return success();
+ }
+
SymbolTableCollection symbolTableCollection;
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
// convert a OpFoldResult into a Value
auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
if (auto value = dyn_cast<Value>(v))
return value;
- return rewriter.create<::mlir::arith::ConstantOp>(
+ return rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(
cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
};
- auto dest = op.getDestination();
+ auto dest = adaptor.getDestination();
auto dstShape = cast<ShapedType>(dest.getType()).getShape();
Value array = dest;
if (isa<RankedTensorType>(array.getType())) {
// If the destination is a tensor, cast it to a memref
auto tensorType = MemRefType::get(
dstShape, cast<ShapedType>(array.getType()).getElementType());
- array = rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array)
- .getResult();
+ array =
+ rewriter.create<bufferization::ToMemrefOp>(loc, tensorType, array);
}
auto rank = cast<ShapedType>(array.getType()).getRank();
- auto opSplitAxes = op.getSplitAxes().getAxes();
- auto mesh = op.getMesh();
+ auto opSplitAxes = adaptor.getSplitAxes().getAxes();
+ auto mesh = adaptor.getMesh();
auto meshOp = getMesh(op, symbolTableCollection);
- auto haloSizes =
- getMixedValues(op.getStaticHaloSizes(), op.getHaloSizes(), rewriter);
// subviews need Index values
for (auto &sz : haloSizes) {
- if (auto value = dyn_cast<Value>(sz)) {
+ if (auto value = dyn_cast<Value>(sz))
sz =
rewriter
.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
.getResult();
- }
}
// most of the offset/size/stride data is the same for all dims
@@ -530,8 +535,8 @@ struct ConvertUpdateHaloOp
: haloSizes[currHaloDim * 2];
// Check if we need to send and/or receive
// Processes on the mesh borders have only one neighbor
- auto to = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
- auto from = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
+ auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
auto hasFrom = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, from, zero);
auto hasTo = rewriter.create<arith::CmpIOp>(
@@ -564,8 +569,25 @@ struct ConvertUpdateHaloOp
offsets[dim] = orgOffset;
};
- genSendRecv(false);
- genSendRecv(true);
+ auto get_i32val = [&](OpFoldResult &v) {
+ return isa<Value>(v)
+ ? cast<Value>(v)
+ : rewriter.create<arith::ConstantOp>(
+ loc,
+ rewriter.getI32IntegerAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
+ };
+
+ for (int i = 0; i < 2; ++i) {
+ Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+ auto hasSize = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sgt, haloSz, zero);
+ rewriter.create<scf::IfOp>(loc, hasSize,
+ [&](OpBuilder &builder, Location loc) {
+ genSendRecv(i > 0);
+ builder.create<scf::YieldOp>(loc);
+ });
+ }
// the shape for lower dims include higher dims' halos
dimSizes[dim] = shape[dim];
@@ -583,7 +605,7 @@ struct ConvertUpdateHaloOp
loc, op.getResult().getType(), array,
/*restrict=*/true, /*writable=*/true));
}
- return mlir::success();
+ return success();
}
};
From 3e7d98dcdec057166e26d5b3303d53250ec49bc9 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:52:26 +0100
Subject: [PATCH 3/7] Fixing definition of ShardShapeOp and lowering it to MPI
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 22 ++-
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 154 +++++++++++++++++-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 15 +-
.../Extensions/MeshShardingExtensions.cpp | 19 ++-
.../MeshToMPI/convert-shardshape-to-mpi.mlir | 75 +++++++++
mlir/test/Dialect/Mesh/ops.mlir | 10 +-
.../test/Dialect/Tensor/mesh-spmdization.mlir | 11 +-
7 files changed, 269 insertions(+), 37 deletions(-)
create mode 100644 mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
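
In short, mesh.shard_shape now takes the tensor dims and the device multi-index
as mixed static/dynamic lists, and the lowering computes the local shard size
per sharded dimension. Without sharded_dims_offsets the size is
dim / numShards, with the remainder going to the trailing shards; e.g. a
dimension of size 16 split over a mesh axis of size 5 yields shard sizes
3, 3, 3, 3, 4 (16 / 5 = 3, remainder 1, so only the last shard gets an extra
element). The new assembly format is exercised in the added test, e.g.

  %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index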
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 031e6f63bcb42..f59c4c4c67517 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -345,24 +345,32 @@ def Mesh_GetShardingOp : Mesh_Op<"get_sharding", [Pure]> {
}];
}
-def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [Pure]> {
- let summary = "Get the shard shape of a given process/device.";
+def Mesh_ShardShapeOp : Mesh_Op<"shard_shape", [
+ Pure, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
+ ]> {
+ let summary = "Get the shard shape for a given process/device.";
let description = [{
- The device/process id is a linearized id of the device/process in the mesh.
+ The device/process id is a multi-index of the device/process in the mesh.
This operation might be used during spmdization when the shard shape depends
on (non-constant) values used in `mesh.sharding`.
}];
let arguments = (ins
- DenseI64ArrayAttr:$shape,
+ DenseI64ArrayAttr:$dims,
+ Variadic<Index>:$dims_dynamic,
Mesh_Sharding:$sharding,
- Index:$device
+ DenseI64ArrayAttr:$device,
+ Variadic<Index>:$device_dynamic
);
let results = (outs Variadic<Index>:$result);
let assemblyFormat = [{
- custom<DimensionList>($shape) $sharding $device attr-dict `:` type($result)
+ `dims` `=` custom<DynamicIndexList>($dims_dynamic, $dims)
+ `sharding` `=` $sharding
+ `device` `=` custom<DynamicIndexList>($device_dynamic, $device)
+ attr-dict `:` type(results)
}];
let builders = [
- OpBuilder<(ins "ArrayRef<int64_t>":$shape, "Value":$sharding, "Value":$device)>
+ OpBuilder<(ins "ArrayRef<int64_t>":$dims, "ArrayRef<Value>":$dims_dyn, "Value":$sharding, "ValueRange":$device)>
];
}
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ee05276eb8366..3f95b803bc2ed 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -380,23 +380,159 @@ struct ConvertNeighborsLinearIndicesOp
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
- .getResult();
+ rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
builder.create<scf::YieldOp>(
loc, multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
- return mlir::success();
+ return success();
}
};
-struct ConvertUpdateHaloOp
- : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::UpdateHaloOp op,
- mlir::PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+ if (!sharding) {
+ return op->emitError()
+             << "Expected ShardingOp as defining op for sharding"
+ << " but found " << adaptor.getSharding()[0].getDefiningOp();
+ }
+
+ // Compute the sharded shape by applying the sharding to the input shape.
+ // Without shardedDimsOffsets in the sharding, the shard shape is computed
+ // by dividing the dimension size by the number of shards in that dimension
+ // (which is given by the size of the mesh axes provided in split-axes).
+ // Odd elements get distributed to trailing shards.
+    // If shardedDimsOffsets are provided, the shard shape is computed by
+ // subtracting the offset of the current shard from the offset of the next
+ // shard.
+
+ Location loc = op.getLoc();
+ Type index = rewriter.getIndexType();
+
+    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and
+    // device we have a 1:1 conversion.
+    // For simpler access, fill a vector with the dynamic dims.
+ SmallVector<Value> dynDims, dynDevice;
+ for (auto dim : adaptor.getDimsDynamic()) {
+ // type conversion should be 1:1 for ints
+ assert(dim.size() == 1);
+ dynDims.emplace_back(dim[0]);
+ }
+ // same for device
+ for (auto device : adaptor.getDeviceDynamic()) {
+ assert(device.size() == 1);
+ dynDevice.emplace_back(device[0]);
+ }
+
+ // To keep the code simple, convert dims/device to values when they are
+ // attributes. Count on canonicalization to fold static values.
+ auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+ auto multiIdx =
+ getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+ // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(sharding, symbolTableCollection);
+ // For now we only support static mesh shapes
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return failure();
+
+ auto splitAxes = sharding.getSplitAxes().getAxes();
+ // shardedDimsOffsets are optional and might be Values (not attributes).
+ // Also, the shardId might be dynamic which means the position in the
+ // shardedDimsOffsets is not statically known. Create a tensor of the
+ // shardedDimsOffsets and later extract the offsets for computing the
+ // local shard-size.
+ Value shardedDimsOffs;
+ {
+ auto tmp = getMixedAsValues(
+ rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+ sharding.getDynamicShardedDimsOffsets(), index);
+ if (!tmp.empty())
+ shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+ loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+ }
+
+    // With a static mesh shape the sizes of the split axes are known.
+    // Hence the start/pos for each split axis in shardedDimsOffsets can be
+    // computed statically.
+ int64_t pos = 0;
+ SmallVector<Value> shardShape;
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+ Value one =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+ // Iterate over the dimensions of the tensor shape, get their split Axes,
+ // and compute the sharded shape.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ // Trailing dimensions might not be annotated.
+ if (i < splitAxes.size() && !splitAxes[i].empty()) {
+ auto axes = splitAxes[i];
+ // The current dimension might not be sharded.
+        // Create a value from the static position in shardedDimsOffsets.
+ Value posVal =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+ // Get the index of the local shard in the mesh axis.
+ Value idx = multiIdx[axes[0]];
+ auto _numShards =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ if (shardedDimsOffs) {
+ // If sharded dims offsets are provided, use them to compute the
+ // sharded shape.
+ if (axes.size() > 1) {
+ return op->emitError() << "Only single axis sharding is "
+ << "supported for each dimension.";
+ }
+ idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+ // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
+ Value off =
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+ idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+ Value nextOff =
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+ Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+ shardShape.emplace_back(sz);
+ } else {
+ auto numShards = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(_numShards));
+ // Compute shard dim size by distributing odd elements to trailing
+ // shards:
+ // sz = dim / numShards
+ // + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
+ Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
+ Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
+ sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+ auto cond = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, idx, sz1);
+ Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
+ sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
+ shardShape.emplace_back(sz);
+ }
+ pos += _numShards + 1; // add one for the total size.
+ } // else no sharding if split axis is empty or no split axis
+ // If no size was added -> no sharding in this dimension.
+ if (shardShape.size() <= i)
+ shardShape.emplace_back(dim);
+ }
+ assert(shardShape.size() == shape.size());
+ rewriter.replaceOp(op, shardShape);
+ return success();
+ }
+};
+
+struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// The input/output memref is assumed to be in C memory order.
// Halos are exchanged as 2 blocks per dimension (one for each side: down
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304ede195c762..3e9f86fde64f3 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -831,12 +831,19 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_,
// mesh.shard_shape
//===----------------------------------------------------------------------===//
+void ShardShapeOp::getAsmResultNames(
+ function_ref<void(Value, StringRef)> setNameFn) {
+ setNameFn(getResult()[0], "shard_shape");
+}
+
void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
::mlir::OperationState &odsState,
- ::llvm::ArrayRef<int64_t> shape,
- ::mlir::Value sharding, ::mlir::Value device) {
- SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
- build(odsBuilder, odsState, resType, shape, sharding, device);
+ ::llvm::ArrayRef<int64_t> dims,
+ ArrayRef<Value> dims_dyn, ::mlir::Value sharding,
+ ::mlir::ValueRange device) {
+ SmallVector<mlir::Type> resType(dims.size(), odsBuilder.getIndexType());
+ build(odsBuilder, odsState, resType, dims, dims_dyn, sharding,
+ SmallVector<int64_t>(device.size(), ShapedType::kDynamic), device);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index b2acbf20b3fb9..b3d69eb5e1a23 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -50,10 +50,10 @@ struct CreatorOpShardingInterface
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTable,
OpBuilder &builder) const {
- auto shardType = cast<ShapedType>(mesh::shardType(
- op->getResult(0).getType(),
- mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable),
- resultShardings[0]));
+ auto mesh =
+ mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable);
+ auto shardType = cast<ShapedType>(
+ mesh::shardType(op->getResult(0).getType(), mesh, resultShardings[0]));
Operation *newOp = nullptr;
// if the sharding introduces a new dynamic dimension, we take it from
// the dynamic sharding info. For now bail out if it's not
@@ -66,18 +66,19 @@ struct CreatorOpShardingInterface
assert(oldType.getRank() == shardType.getRank());
int currOldOprndNum = -1;
mesh::ShardShapeOp shapeForDevice;
- Value device;
+ ValueRange device;
Operation *newSharding = nullptr;
for (auto i = 0; i < oldType.getRank(); ++i) {
if (!oldType.isDynamicDim(i) && shardType.isDynamicDim(i)) {
if (!newSharding) {
newSharding =
builder.create<ShardingOp>(op->getLoc(), resultShardings[0]);
- device = builder.create<mesh::ProcessLinearIndexOp>(
- op->getLoc(), resultShardings[0].getMesh());
+ device =
+ builder.create<mesh::ProcessMultiIndexOp>(op->getLoc(), mesh)
+ .getResults();
shapeForDevice = builder.create<mesh::ShardShapeOp>(
- op->getLoc(), oldType.getShape(), newSharding->getResult(0),
- device);
+ op->getLoc(), oldType.getShape(), spmdizedOperands,
+ newSharding->getResult(0), device);
}
newOperands.emplace_back(shapeForDevice.getResult()[i]);
} else if (oldType.isDynamicDim(i)) {
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
new file mode 100644
index 0000000000000..1cc848333ced2
--- /dev/null
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --convert-mesh-to-mpi -canonicalize | FileCheck %s
+
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+
+ // CHECK: mesh.mesh @mesh0
+ mesh.mesh @mesh0(shape = 3x4x5)
+
+  // Notice: comm_world_rank/linear index 24 is multi-index [1, 0, 4] in @mesh0
+
+ // all shards are equal
+ // CHECK-LABEL: func.func @shard_shape_equal() -> (index, index, index) {
+ func.func @shard_shape_equal() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // last shard in last dim gets an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_1() -> (index, index, index) {
+ func.func @shard_shape_odd_1() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 16] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc4]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // all except first shard in second dim get an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_2() -> (index, index, index) {
+ func.func @shard_shape_odd_2() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ %1:3 = mesh.shard_shape dims = [%c9, 15, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // all except first shard in first dim get an extra element
+ // CHECK-LABEL: func.func @shard_shape_odd_3() -> (index, index, index) {
+ func.func @shard_shape_odd_3() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]] : !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ // CHECK-DAG: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK-DAG: [[vc4:%.*]] = arith.constant 4 : index
+ %1:3 = mesh.shard_shape dims = [11, 12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc4]], [[vc3]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+
+ // extract from sharded_dims_offsets
+ // CHECK-LABEL: func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ func.func @shard_shape_sharded_dims_offs() -> (index, index, index) {
+ %sharding = mesh.sharding @mesh0 split_axes = [[0], [1], [2]]
+ sharded_dims_offsets = [0, 1, 4, 9, 0, 2, 6, 12, 12, 0, 3, 6, 9, 12, 15]: !mesh.sharding
+ %0:3 = mesh.process_multi_index on @mesh0 : index, index, index
+ %c9 = arith.constant 9 : index
+ %c12 = arith.constant 12 : index
+ // CHECK: [[vc3:%.*]] = arith.constant 3 : index
+ // CHECK: [[vc2:%.*]] = arith.constant 2 : index
+ %1:3 = mesh.shard_shape dims = [%c9, %c12, 15] sharding = %sharding device = [%0#0, %0#1, %0#2] : index, index, index
+ // CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
+ return %1#0, %1#1, %1#2 : index, index, index
+ }
+}
\ No newline at end of file
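The "Notice" comment in the test above relies on the row-major de-linearization used for mesh.process_multi_index; a minimal sketch of that mapping in plain C++ (illustrative only; toMultiIndex is a made-up helper mirroring linearToMultiIndex):

#include <array>
#include <cassert>
#include <cstdint>

// Row-major de-linearization: peel off the trailing (fastest-varying) mesh
// dimension first.
static std::array<int64_t, 3> toMultiIndex(int64_t linear,
                                           const std::array<int64_t, 3> &dims) {
  std::array<int64_t, 3> idx{};
  for (int i = 2; i >= 0; --i) {
    idx[i] = linear % dims[i];
    linear /= dims[i];
  }
  return idx;
}

int main() {
  // @mesh0 has shape 3x4x5; comm_world_rank 24 maps to multi-index [1, 0, 4].
  auto idx = toMultiIndex(24, {3, 4, 5});
  assert(idx[0] == 1 && idx[1] == 0 && idx[2] == 4);
  return 0;
}
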
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 43a75bf3d8040..3d133f2255772 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -157,10 +157,12 @@ func.func @mesh_shard_shape() {
%c3 = arith.constant 3 : index
// CHECK-NEXT: %[[S:.*]] = mesh.sharding @mesh0 split_axes = {{\[\[}}]] : !mesh.sharding
%s = mesh.sharding @mesh0 split_axes = [[]] : !mesh.sharding
- // CHECK-NEXT: mesh.shard_shape 8x? %[[S]] %[[C3]] : index, index
- %shp:2 = mesh.shard_shape 8x? %s %c3 : index, index
- // CHECK-NEXT: mesh.shard_shape 8x4 %[[S]] %[[C3]] : index, index
- %shp1:2 = mesh.shard_shape 8x4 %s %c3 : index, index
+ // CHECK-NEXT: mesh.shard_shape dims = [8, %[[C3]]
+ // CHECK-SAME: ] sharding = %[[S]] device = [%[[C3]]
+ // CHECK-SAME: ] : index, index
+ %shp:2 = mesh.shard_shape dims = [8, %c3] sharding = %s device = [%c3] : index, index
+ // CHECK-NEXT: mesh.shard_shape dims = [8, 4] sharding = %[[S]] device = [3] : index, index
+ %shp1:2 = mesh.shard_shape dims = [8, 4] sharding = %s device = [3] : index, index
return
}
diff --git a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
index 5443eea83aa2d..01cf5972177f4 100644
--- a/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
+++ b/mlir/test/Dialect/Tensor/mesh-spmdization.mlir
@@ -10,8 +10,9 @@ func.func @tensor_empty_static_sharded_dims_offsets() -> () {
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
%sharded= mesh.shard %b to %sharding : tensor<8x16xf32>
// CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x16 %[[sharding]] %[[proc_linear_idx]] : index, index
+ // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, 16] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
// CHECK: tensor.empty(%[[V0]]#0) : tensor<?x16xf32>
return
@@ -24,8 +25,10 @@ func.func @tensor_empty_dynamic_sharded_dims_offsets(%arg0 : index) -> () {
%sharding = mesh.sharding @mesh_1d_4 split_axes = [[0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
%sharded= mesh.shard %b to %sharding : tensor<8x?xf32>
// CHECK: %[[sharding:.*]] = mesh.sharding @mesh_1d_4 split_axes = {{\[\[}}0]] sharded_dims_offsets = [0, 1, 4, 7, 8] : !mesh.sharding
- // CHECK: %[[proc_linear_idx:.*]] = mesh.process_linear_index on @mesh_1d_4 : index
- // CHECK: %[[V0:.*]]:2 = mesh.shard_shape 8x? %[[sharding]] %[[proc_linear_idx]] : index, index
+ // CHECK: %[[proc_multi_idx:.*]] = mesh.process_multi_index on @mesh_1d_4 : index
+ // CHECK: %[[V0:.*]]:2 = mesh.shard_shape dims = [8, %[[A0]]
+ // CHECK-SAME: ] sharding = %[[sharding]] device = [%[[proc_multi_idx]]
+ // CHECK-SAME: ] : index, index
// CHECK: tensor.empty(%[[V0]]#0, %[[A0]]) : tensor<?x?xf32>
return
>From 3428e5064fd30e30598f74aeb9cb5b6b7c4aec45 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:55:27 +0100
Subject: [PATCH 4/7] using DLTI instead of global symbol for static rank in
comm_world
---
mlir/include/mlir/Conversion/Passes.td | 8 ++--
mlir/lib/Conversion/MeshToMPI/CMakeLists.txt | 1 +
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 39 ++++++++++---------
.../MeshToMPI/convert-mesh-to-mpi.mlir | 33 ++++++++--------
4 files changed, 42 insertions(+), 39 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cccdf0a8518bf..6074e0e8d822c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -881,10 +881,10 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
let description = [{
This pass converts communication operations from the Mesh dialect to the
MPI dialect.
- If it finds a global named "static_mpi_rank" it will use that splat value
- instead of calling MPI_Comm_rank. This allows optimizations like constant
- shape propagation and fusion because shard/partition sizes depend on the
- rank.
+    If it finds the DLTI attribute "MPI:comm_world_rank" on the module, it
+    will use that integer value instead of calling MPI_Comm_rank. This allows
+ optimizations like constant shape propagation and fusion because
+ shard/partition sizes depend on the rank.
}];
let dependentDialects = [
"memref::MemRefDialect",
diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
index 95815a683f6d6..15560aa61e145 100644
--- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
+++ b/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMeshToMPI
Core
LINK_LIBS PUBLIC
+ MLIRDLTIDialect
MLIRFuncDialect
MLIRIR
MLIRLinalgTransforms
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 3f95b803bc2ed..ed8fe80b18efe 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -284,32 +284,33 @@ struct ConvertProcessMultiIndexOp
}
rewriter.replaceOp(op, mIdx);
- return mlir::success();
+ return success();
}
};
-struct ConvertProcessLinearIndexOp
- : public mlir::OpRewritePattern<mlir::mesh::ProcessLinearIndexOp> {
- using OpRewritePattern::OpRewritePattern;
+class ConvertProcessLinearIndexOp
+ : public OpConversionPattern<ProcessLinearIndexOp> {
+ int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessLinearIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
+public:
+ using OpConversionPattern::OpConversionPattern;
- // Finds a global named "static_mpi_rank" it will use that splat value.
- // Otherwise it defaults to mpi.comm_rank.
+ // Constructor accepting worldRank
+ ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
+ MLIRContext *context, int64_t worldRank_ = -1)
+ : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
- auto loc = op.getLoc();
- auto rankOpName = StringAttr::get(op->getContext(), "static_mpi_rank");
- if (auto globalOp = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
- op, rankOpName)) {
- if (auto initTnsr = globalOp.getInitialValueAttr()) {
- auto val = cast<DenseElementsAttr>(initTnsr).getSplatValue<int64_t>();
- rewriter.replaceOp(op,
- rewriter.create<arith::ConstantIndexOp>(loc, val));
- return mlir::success();
- }
+ LogicalResult
+ matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
+ return success();
}
+
+    // Otherwise create an mpi::CommRankOp and use its result.
auto rank =
rewriter
.create<mpi::CommRankOp>(
diff --git a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
index 90bd80472c2b9..4e60c6f0d4e44 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-mesh-to-mpi.mlir
@@ -60,23 +60,24 @@ func.func @neighbors_dim2(%arg0 : tensor<120x120x120xi8>) -> (index, index) {
// -----
// CHECK: mesh.mesh @mesh0
-mesh.mesh @mesh0(shape = 3x4x5)
-memref.global constant @static_mpi_rank : memref<index> = dense<24>
-func.func @process_multi_index() -> (index, index, index) {
- // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
- // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
- %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
- // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
- return %0#0, %0#1, %0#2 : index, index, index
-}
+module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
+ mesh.mesh @mesh0(shape = 3x4x5)
+ func.func @process_multi_index() -> (index, index, index) {
+ // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+ %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[c1]], %[[c0]], %[[c4]] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+ }
-// CHECK-LABEL: func @process_linear_index
-func.func @process_linear_index() -> index {
- // CHECK: %[[c24:.*]] = arith.constant 24 : index
- %0 = mesh.process_linear_index on @mesh0 : index
- // CHECK: return %[[c24]] : index
- return %0 : index
+ // CHECK-LABEL: func @process_linear_index
+ func.func @process_linear_index() -> index {
+ // CHECK: %[[c24:.*]] = arith.constant 24 : index
+ %0 = mesh.process_linear_index on @mesh0 : index
+ // CHECK: return %[[c24]] : index
+ return %0 : index
+ }
}
// -----
>From ce882e60d9933008fcc61b906eab25aeaa05b529 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 27 Feb 2025 17:57:23 +0100
Subject: [PATCH 5/7] cleanup, aligning to conventions
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 86 ++++++++++-----------
1 file changed, 39 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index ed8fe80b18efe..49860b96d3685 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -76,23 +76,22 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
for (int i = n - 1; i >= 0; --i) {
multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
- if (i > 0) {
+ if (i > 0)
linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
- }
}
return multiIndex;
}
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
- auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+ Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+ Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
}
@@ -247,34 +246,32 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
};
struct ConvertProcessMultiIndexOp
- : public mlir::OpRewritePattern<mlir::mesh::ProcessMultiIndexOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public OpConversionPattern<ProcessMultiIndexOp> {
+ using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::ProcessMultiIndexOp op,
- mlir::PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
// Currently converts its linear index to a multi-dimensional index.
SymbolTableCollection symbolTableCollection;
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
auto meshOp = getMesh(op, symbolTableCollection);
// For now we only support static mesh shapes
- if (ShapedType::isDynamicShape(meshOp.getShape())) {
- return mlir::failure();
- }
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return failure();
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
});
- auto rank =
- rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp).getResult();
+ Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);
// optionally extract subset of mesh axes
- auto axes = op.getAxes();
+ auto axes = adaptor.getAxes();
if (!axes.empty()) {
SmallVector<Value> subIndex;
for (auto axis : axes) {
@@ -319,44 +316,43 @@ class ConvertProcessLinearIndexOp
.getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
- return mlir::success();
+ return success();
}
};
struct ConvertNeighborsLinearIndicesOp
- : public mlir::OpRewritePattern<mlir::mesh::NeighborsLinearIndicesOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public OpConversionPattern<NeighborsLinearIndicesOp> {
+ using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::NeighborsLinearIndicesOp op,
- mlir::PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
    // Computes the neighbor indices along a split axis by simply
// adding/subtracting 1 to the current index in that dimension.
// Assigns -1 if neighbor is out of bounds.
- auto axes = op.getSplitAxes();
+ auto axes = adaptor.getSplitAxes();
// For now only single axis sharding is supported
- if (axes.size() != 1) {
- return mlir::failure();
- }
+ if (axes.size() != 1)
+ return failure();
- auto loc = op.getLoc();
+ Location loc = op.getLoc();
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
- auto mIdx = op.getDevice();
+ auto mIdx = adaptor.getDevice();
auto orgIdx = mIdx[axes[0]];
SmallVector<Value> dims;
llvm::transform(
meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
});
- auto dimSz = dims[axes[0]];
- auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1).getResult();
- auto minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1).getResult();
- auto atBorder = rewriter.create<arith::CmpIOp>(
+ Value dimSz = dims[axes[0]];
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
+ Value atBorder = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, orgIdx,
- rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult());
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
auto down = rewriter.create<scf::IfOp>(
loc, atBorder,
[&](OpBuilder &builder, Location loc) {
@@ -598,11 +594,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// we need the actual shape to compute offsets and sizes
for (auto i = 0; i < rank; ++i) {
auto s = dstShape[i];
- if (ShapedType::isDynamic(s)) {
+ if (ShapedType::isDynamic(s))
shape[i] = rewriter.create<memref::DimOp>(loc, array, s).getResult();
- } else {
+ else
shape[i] = rewriter.getIndexAttr(s);
- }
if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
++currHaloDim;
@@ -610,11 +605,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
offsets[i] = haloSizes[currHaloDim * 2];
// prepare shape and offsets of highest dim's halo exchange
- auto _haloSz =
- rewriter
- .create<arith::AddIOp>(loc, toValue(haloSizes[currHaloDim * 2]),
- toValue(haloSizes[currHaloDim * 2 + 1]))
- .getResult();
+ Value _haloSz = rewriter.create<arith::AddIOp>(
+ loc, toValue(haloSizes[currHaloDim * 2]),
+ toValue(haloSizes[currHaloDim * 2 + 1]));
       // the halo shape of lower dims excludes the halos
dimSizes[i] =
rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
@@ -625,9 +618,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
}
auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
- auto tag = rewriter.create<::mlir::arith::ConstantOp>(loc, tagAttr);
+ auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
- auto zero = rewriter.create<::mlir::arith::ConstantOp>(loc, zeroAttr);
+ auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
rewriter.getIndexType());
@@ -637,9 +630,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// traverse all split axes from high to low dim
for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
auto splitAxes = opSplitAxes[dim];
- if (splitAxes.empty()) {
+ if (splitAxes.empty())
continue;
- }
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
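For the neighbor computation touched above (ConvertNeighborsLinearIndicesOp), a minimal sketch of the intended semantics in plain C++, assuming the made-up helpers toLinear/neighbors (illustrative only): a neighbor that would fall outside the mesh is reported as -1.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Row-major linearization of a multi-index.
static int64_t toLinear(const std::vector<int64_t> &idx,
                        const std::vector<int64_t> &dims) {
  int64_t linear = 0;
  for (size_t i = 0; i < dims.size(); ++i)
    linear = linear * dims[i] + idx[i];
  return linear;
}

// Returns the {down, up} neighbor ranks of `idx` along `axis` in a mesh of
// shape `dims`; -1 marks an out-of-bounds neighbor.
static std::pair<int64_t, int64_t> neighbors(std::vector<int64_t> idx,
                                             const std::vector<int64_t> &dims,
                                             size_t axis) {
  int64_t org = idx[axis];
  int64_t down = -1, up = -1;
  if (org > 0) {
    idx[axis] = org - 1;
    down = toLinear(idx, dims);
  }
  if (org < dims[axis] - 1) {
    idx[axis] = org + 1;
    up = toLinear(idx, dims);
  }
  return {down, up};
}

int main() {
  // Mesh 3x4x5, device [1, 0, 4]: along axis 1 the down neighbor is out of
  // bounds, the up neighbor is [1, 1, 4] -> linear index 29.
  auto [down, up] = neighbors({1, 0, 4}, {3, 4, 5}, 1);
  assert(down == -1 && up == 29);
  return 0;
}
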
>From c586b79a2875f923afa4f504708b459ebfc082cb Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Fri, 28 Feb 2025 10:01:03 +0100
Subject: [PATCH 6/7] Apply suggestions from code review
Co-authored-by: Christian Ulmann <christianulmann at gmail.com>
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 96 +++++++++----------
.../MeshToMPI/convert-shardshape-to-mpi.mlir | 2 +-
2 files changed, 49 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 49860b96d3685..68516ecf01339 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -45,7 +45,7 @@ using namespace mlir;
using namespace mesh;
namespace {
-/// Convert vec of OpFoldResults (ints) into vector of Values.
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
llvm::ArrayRef<int64_t> statics,
ValueRange dynamics,
@@ -55,14 +55,14 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
Type i64 = b.getI64Type();
if (!type)
type = i64;
- assert(i64 == type || b.getIndexType() == type);
+  assert(i64 == type || b.getIndexType() == type && "expected an i64 or an index type");
for (auto s : statics) {
- values.emplace_back(
- ShapedType::isDynamic(s)
- ? *(dyn++)
- : b.create<arith::ConstantOp>(loc, type,
- i64 == type ? b.getI64IntegerAttr(s)
- : b.getIndexAttr(s)));
+ if (s == ShapedType::kDynamic) {
+ values.emplace_back(*(dyn++));
+ } else {
+ TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
+ values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
+ }
}
return values;
};
@@ -129,9 +129,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
ConversionPatternRewriter &rewriter) const override {
auto splitAxes = op.getSplitAxes().getAxes();
int64_t maxNAxes = 0;
- for (auto axes : splitAxes) {
+ for (auto axes : splitAxes)
maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
- }
// To hold the split axes, create empty 2d tensor with shape
// {splitAxes.size(), max-size-of-split-groups}.
@@ -139,7 +138,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
Location loc = op.getLoc();
auto i16 = rewriter.getI16Type();
auto i64 = rewriter.getI64Type();
- int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+ std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
+ maxNAxes};
Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
auto attr = IntegerAttr::get(i16, 0xffff);
Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
@@ -147,15 +147,15 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
.getResult(0);
// explicitly write values into tensor row by row
- int64_t strides[] = {1, 1};
+ std::array<int64_t, 2> strides = {1, 1};
int64_t nSplits = 0;
ValueRange empty = {};
for (auto [i, axes] : llvm::enumerate(splitAxes)) {
int64_t size = axes.size();
if (size > 0)
++nSplits;
- int64_t offs[] = {(int64_t)i, 0};
- int64_t sizes[] = {1, size};
+ std::array<int64_t, 2> offs = {(int64_t)i, 0};
+ std::array<int64_t, 2> sizes = {1, size};
auto tensorType = RankedTensorType::get({size}, i16);
auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
@@ -165,7 +165,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
// To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
// Store the halo sizes in the tensor.
- auto haloSizes =
+ SmallVector<Value> haloSizes =
getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
adaptor.getDynamicHaloSizes());
auto type = RankedTensorType::get({nSplits, 2}, i64);
@@ -190,7 +190,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
} else {
SymbolTableCollection symbolTableCollection;
auto meshOp = getMesh(op, symbolTableCollection);
- auto maxSplitSize = 0;
+ int64_t maxSplitSize = 0;
for (auto axes : splitAxes) {
int64_t splitSize =
collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
@@ -206,7 +206,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
resOffsets =
rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
- auto offsets =
+ SmallVector<Value> offsets =
getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
adaptor.getDynamicShardedDimsOffsets());
int64_t curr = 0;
@@ -217,8 +217,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
++splitSize; // add one for the total size
ArrayRef<Value> values(&offsets[curr], splitSize);
Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
- int64_t offs[] = {(int64_t)i, 0};
- int64_t sizes[] = {1, splitSize};
+ std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
+ std::array<int64_t, 2> sizes = {1, splitSize};
resOffsets = rewriter.create<tensor::InsertSliceOp>(
loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
curr += splitSize;
@@ -275,9 +275,9 @@ struct ConvertProcessMultiIndexOp
if (!axes.empty()) {
SmallVector<Value> subIndex;
for (auto axis : axes) {
- subIndex.push_back(mIdx[axis]);
+ subIndex.emplace_back(mIdx[axis]);
}
- mIdx = subIndex;
+ mIdx = std::move(subIndex);
}
rewriter.replaceOp(op, mIdx);
@@ -294,8 +294,8 @@ class ConvertProcessLinearIndexOp
// Constructor accepting worldRank
ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
- MLIRContext *context, int64_t worldRank_ = -1)
- : OpConversionPattern(typeConverter, context), worldRank(worldRank_) {}
+ MLIRContext *context, int64_t worldRank = -1)
+ : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
LogicalResult
matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
@@ -429,8 +429,9 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
// To keep the code simple, convert dims/device to values when they are
// attributes. Count on canonicalization to fold static values.
- auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
- auto multiIdx =
+ SmallVector<Value> shape =
+ getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+ SmallVector<Value> multiIdx =
getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
// Get the MeshOp, the mesh shape is needed to compute the sharded shape.
@@ -448,7 +449,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
// local shard-size.
Value shardedDimsOffs;
{
- auto tmp = getMixedAsValues(
+ SmallVector<Value> tmp = getMixedAsValues(
rewriter, loc, sharding.getStaticShardedDimsOffsets(),
sharding.getDynamicShardedDimsOffsets(), index);
if (!tmp.empty())
@@ -478,7 +479,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
// Get the index of the local shard in the mesh axis.
Value idx = multiIdx[axes[0]];
- auto _numShards =
+ auto numShards =
collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
if (shardedDimsOffs) {
// If sharded dims offsets are provided, use them to compute the
@@ -497,22 +498,22 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
shardShape.emplace_back(sz);
} else {
- auto numShards = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIndexAttr(_numShards));
+ Value numShardsVal = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIndexAttr(numShards));
// Compute shard dim size by distributing odd elements to trailing
// shards:
// sz = dim / numShards
// + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
- Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShards);
- Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShards);
- sz1 = rewriter.create<arith::SubIOp>(loc, numShards, sz1);
+ Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
+ Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
+ sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
auto cond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, idx, sz1);
Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
shardShape.emplace_back(sz);
}
- pos += _numShards + 1; // add one for the total size.
+ pos += numShards + 1; // add one for the total size.
} // else no sharding if split axis is empty or no split axis
// If no size was added -> no sharding in this dimension.
if (shardShape.size() <= i)
@@ -698,25 +699,24 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
offsets[dim] = orgOffset;
};
- auto get_i32val = [&](OpFoldResult &v) {
- return isa<Value>(v)
- ? cast<Value>(v)
- : rewriter.create<arith::ConstantOp>(
- loc,
- rewriter.getI32IntegerAttr(
- cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
- };
-
- for (int i = 0; i < 2; ++i) {
- Value haloSz = get_i32val(haloSizes[currHaloDim * 2 + i]);
+ auto doSendRecv = [&](int upOrDown) {
+ OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
+ Value haloSz = dyn_cast<Value>(v);
+ if (!haloSz)
+ haloSz = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(
+ cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
auto hasSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, haloSz, zero);
rewriter.create<scf::IfOp>(loc, hasSize,
[&](OpBuilder &builder, Location loc) {
- genSendRecv(i > 0);
+ genSendRecv(upOrDown > 0);
builder.create<scf::YieldOp>(loc);
});
- }
+ };
+
+ doSendRecv(0);
+ doSendRecv(1);
       // the shape for lower dims includes higher dims' halos
dimSizes[dim] = shape[dim];
@@ -775,8 +775,8 @@ struct ConvertMeshToMPIPass
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
auto i16 = IntegerType::get(type.getContext(), 16);
auto i64 = IntegerType::get(type.getContext(), 64);
- std::array<int64_t, 2> shp{ShapedType::kDynamic,
- ShapedType::kDynamic};
+ std::array<int64_t, 2> shp = {ShapedType::kDynamic,
+ ShapedType::kDynamic};
results.emplace_back(RankedTensorType::get(shp, i16));
results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
results.emplace_back(RankedTensorType::get(shp, i64));
diff --git a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
index 1cc848333ced2..2ef52efa5d030 100644
--- a/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
+++ b/mlir/test/Conversion/MeshToMPI/convert-shardshape-to-mpi.mlir
@@ -72,4 +72,4 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 24> } {
// CHECK: return [[vc3]], [[vc2]], [[vc3]] : index, index, index
return %1#0, %1#1, %1#2 : index, index, index
}
-}
\ No newline at end of file
+}
>From db651b9e0fc04f8b9a1595d7b3e56ed3a1aa26ee Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 28 Feb 2025 10:55:54 +0100
Subject: [PATCH 7/7] --amend
---
mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 68516ecf01339..04d1c0d338218 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -45,7 +45,8 @@ using namespace mlir;
using namespace mesh;
namespace {
-/// Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
+/// Converts a vector of OpFoldResults (ints) into vector of Values of the
+/// provided type.
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
llvm::ArrayRef<int64_t> statics,
ValueRange dynamics,
@@ -55,7 +56,8 @@ static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
Type i64 = b.getI64Type();
if (!type)
type = i64;
-  assert(i64 == type || b.getIndexType() == type && "expected an i64 or an index type");
+  assert(i64 == type ||
+         b.getIndexType() == type && "expected an i64 or an index type");
for (auto s : statics) {
if (s == ShapedType::kDynamic) {
values.emplace_back(*(dyn++));
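To make the intent of getMixedAsValues concrete, here is a minimal sketch of its merge behavior on a simplified stand-in representation (plain C++; std::string stands in for mlir::Value and kDynamic for ShapedType::kDynamic; illustrative only, not part of the patch):

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

static constexpr int64_t kDynamic = INT64_MIN; // stand-in for ShapedType::kDynamic

// Static entries become constants; each kDynamic entry consumes the next
// dynamic value, preserving the original order.
static std::vector<std::string>
mixedAsValues(const std::vector<int64_t> &statics,
              const std::vector<std::string> &dynamics) {
  std::vector<std::string> values;
  auto dyn = dynamics.begin();
  for (int64_t s : statics)
    values.push_back(s == kDynamic ? *(dyn++) : "c" + std::to_string(s));
  return values;
}

int main() {
  // dims = [8, ?, 16] with one dynamic value %d -> [c8, %d, c16]
  auto vals = mixedAsValues({8, kDynamic, 16}, {"%d"});
  assert(vals[0] == "c8" && vals[1] == "%d" && vals[2] == "c16");
  return 0;
}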