[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)
Christian Ulmann
llvmlistbot at llvm.org
Thu Feb 27 23:35:35 PST 2025
================
@@ -48,59 +76,202 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
for (int i = n - 1; i >= 0; --i) {
multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
- if (i > 0) {
+ if (i > 0)
linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
- }
}
return multiIndex;
}
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
- auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+ Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+ Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
}
return linearIndex;
}
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp)
+ return failure();
+ auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+ if (!shardingOp)
+ return failure();
+
+ rewriter.replaceOp(op, shardingOp.getResult());
+ return success();
+ }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by the type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splitAxes = op.getSplitAxes().getAxes();
+ int64_t maxNAxes = 0;
+ for (auto axes : splitAxes) {
+ maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+ }
+
+    // To hold the split axes, create an empty 2d tensor with shape
+ // {splitAxes.size(), max-size-of-split-groups}.
+ // Set trailing elements for smaller split-groups to -1.
+ Location loc = op.getLoc();
+ auto i16 = rewriter.getI16Type();
+ auto i64 = rewriter.getI64Type();
+ int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+ Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+ auto attr = IntegerAttr::get(i16, 0xffff);
+ Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+ resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+ .getResult(0);
+
+    // Explicitly write the values into the tensor row by row.
+ int64_t strides[] = {1, 1};
+ int64_t nSplits = 0;
+ ValueRange empty = {};
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t size = axes.size();
+ if (size > 0)
+ ++nSplits;
+ int64_t offs[] = {(int64_t)i, 0};
+ int64_t sizes[] = {1, size};
+ auto tensorType = RankedTensorType::get({size}, i16);
+ auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+ auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+ resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+ loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+ }
+
+    // To hold the halo sizes, create a 2d tensor with shape {nSplits, 2}.
+ // Store the halo sizes in the tensor.
+ auto haloSizes =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+ adaptor.getDynamicHaloSizes());
+ auto type = RankedTensorType::get({nSplits, 2}, i64);
+ Value resHaloSizes =
+ haloSizes.empty()
+ ? rewriter
+ .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+ i64)
+ .getResult()
+ : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+ .getResult();
+
+    // To hold the sharded dims offsets, create a tensor with shape
+    // {nSplits, maxSplitSize+1}. Store the offsets in the tensor but set
+    // trailing elements for smaller split-groups to -1. Computing the max
+    // size of the split groups requires collectiveProcessGroupSize (which in
+    // turn needs the MeshOp).
+ Value resOffsets;
+ if (adaptor.getStaticShardedDimsOffsets().empty()) {
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{0, 0}, i64);
+ } else {
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(op, symbolTableCollection);
+ auto maxSplitSize = 0;
+ for (auto axes : splitAxes) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic);
+ maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+ }
+ assert(maxSplitSize);
+ ++maxSplitSize; // add one for the total size
+
+ resOffsets = rewriter.create<tensor::EmptyOp>(
+ loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+ Value zero = rewriter.create<arith::ConstantOp>(
+ loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+ resOffsets =
+ rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
+ auto offsets =
+ getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+ adaptor.getDynamicShardedDimsOffsets());
+ int64_t curr = 0;
+ for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+ int64_t splitSize =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+ ++splitSize; // add one for the total size
+ ArrayRef<Value> values(&offsets[curr], splitSize);
+ Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+ int64_t offs[] = {(int64_t)i, 0};
----------------
Dinistro wrote:
```suggestion
int64_t offs[] = {static_cast<int64_t>(i), 0};
```
Nit: Don't use C-style casts.
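
For context, a minimal standalone sketch (not taken from the patch) of why `static_cast` is preferable here: a C-style cast silently falls back to `const_cast`/`reinterpret_cast` in situations where a `static_cast` would be rejected, so questionable conversions compile without complaint.

```c++
#include <cstddef>
#include <cstdint>

void castExample(std::size_t i) {
  // Explicit conversion; the intent (widening/narrowing) is visible and the
  // compiler rejects anything beyond a plain value conversion.
  int64_t offs0 = static_cast<int64_t>(i);

  const int c = 7;
  // A C-style cast compiles even when the conversion is dubious: here it
  // silently casts away const (acting as a const_cast), which is easy to miss.
  int *p = (int *)&c;
  // int *q = static_cast<int *>(&c); // error: static_cast cannot cast away const

  (void)offs0;
  (void)p;
}
```

In the suggested form, `static_cast<int64_t>(i)` makes the intended widening of the enumerate index explicit and cannot accidentally strip qualifiers or reinterpret the value.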
https://github.com/llvm/llvm-project/pull/129048
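
As an aside for readers skimming the hunk above: the two helpers at its top build IR for an ordinary row-major (mixed-radix) index conversion. A plain-scalar sketch of the same arithmetic, illustrative only and not part of the patch:

```c++
#include <cstdint>
#include <vector>

// Scalar equivalent of linearToMultiIndex: peel off the fastest-varying
// dimension first via remainder, then divide to move to the next dimension.
std::vector<int64_t> linearToMulti(int64_t linear,
                                   const std::vector<int64_t> &dims) {
  std::vector<int64_t> multi(dims.size());
  for (int64_t i = static_cast<int64_t>(dims.size()) - 1; i >= 0; --i) {
    multi[i] = linear % dims[i];
    if (i > 0)
      linear /= dims[i];
  }
  return multi;
}

// Scalar equivalent of multiToLinearIndex: accumulate index * stride while
// growing the stride by each dimension, innermost first.
int64_t multiToLinear(const std::vector<int64_t> &multi,
                      const std::vector<int64_t> &dims) {
  int64_t linear = 0, stride = 1;
  for (int64_t i = static_cast<int64_t>(multi.size()) - 1; i >= 0; --i) {
    linear += multi[i] * stride;
    stride *= dims[i];
  }
  return linear;
}
```

For example, with dims {2, 3, 4}, linear index 11 maps to {0, 2, 3} and back; the patch emits the same computation as arith ops on index values.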