[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Feb 28 01:51:31 PST 2025
================
@@ -206,23 +377,159 @@ struct ConvertNeighborsLinearIndicesOp
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
- .getResult();
+ rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
builder.create<scf::YieldOp>(
loc, multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
- return mlir::success();
+ return success();
}
};
-struct ConvertUpdateHaloOp
- : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+ if (!sharding) {
+ return op->emitError()
+ << "Expected SharingOp as defining op for sharding"
+ << " but found " << adaptor.getSharding()[0].getDefiningOp();
+ }
+
+ // Compute the sharded shape by applying the sharding to the input shape.
+ // Without shardedDimsOffsets in the sharding, the shard shape is computed
+ // by dividing the dimension size by the number of shards in that dimension
+ // (which is given by the size of the mesh axes provided in split-axes).
+ // Odd elements get distributed to trailing shards.
+ // If a shardedDimsOffsets is provided, the shard shape is computed by
+ // subtracting the offset of the current shard from the offset of the next
+ // shard.
+
+ Location loc = op.getLoc();
+ Type index = rewriter.getIndexType();
+
+ // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+ // The operands in the adaptor are a vector<ValueRange>. For dims and device
+ // we have a 1:1 conversion.
+ // For simpler access fill a vector with the dynamic dims.
+ SmallVector<Value> dynDims, dynDevice;
+ for (auto dim : adaptor.getDimsDynamic()) {
+ // type conversion should be 1:1 for ints
+ assert(dim.size() == 1);
+ dynDims.emplace_back(dim[0]);
+ }
+ // same for device
+ for (auto device : adaptor.getDeviceDynamic()) {
+ assert(device.size() == 1);
+ dynDevice.emplace_back(device[0]);
+ }
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::UpdateHaloOp op,
- mlir::PatternRewriter &rewriter) const override {
+ // To keep the code simple, convert dims/device to values when they are
+ // attributes. Count on canonicalization to fold static values.
+ auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+ auto multiIdx =
+ getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+ // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+ SymbolTableCollection symbolTableCollection;
+ auto meshOp = getMesh(sharding, symbolTableCollection);
+ // For now we only support static mesh shapes
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return failure();
+
+ auto splitAxes = sharding.getSplitAxes().getAxes();
+ // shardedDimsOffsets are optional and might be Values (not attributes).
+ // Also, the shardId might be dynamic which means the position in the
+ // shardedDimsOffsets is not statically known. Create a tensor of the
+ // shardedDimsOffsets and later extract the offsets for computing the
+ // local shard-size.
+ Value shardedDimsOffs;
+ {
+ auto tmp = getMixedAsValues(
+ rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+ sharding.getDynamicShardedDimsOffsets(), index);
+ if (!tmp.empty())
+ shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+ loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+ }
+
+ // With static mesh shape the sizes of the split axes are known.
+ // Hence the start/pos for each split axis in shardedDimsOffsets can be
+ // computed statically.
+ int64_t pos = 0;
+ SmallVector<Value> shardShape;
+ Value zero =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+ Value one =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+ // Iterate over the dimensions of the tensor shape, get their split Axes,
+ // and compute the sharded shape.
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ // Trailing dimensions might not be annotated.
+ if (i < splitAxes.size() && !splitAxes[i].empty()) {
+ auto axes = splitAxes[i];
+ // The current dimension might not be sharded.
+ // Create a value from the static position in shardedDimsOffsets.
+ Value posVal =
+ rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+ // Get the index of the local shard in the mesh axis.
+ Value idx = multiIdx[axes[0]];
+ auto _numShards =
+ collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+ if (shardedDimsOffs) {
+ // If sharded dims offsets are provided, use them to compute the
+ // sharded shape.
+ if (axes.size() > 1) {
+ return op->emitError() << "Only single axis sharding is "
+ << "supported for each dimension.";
+ }
+ idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+ // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
+ Value off =
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+ idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+ Value nextOff =
+ rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+ Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+ shardShape.emplace_back(sz);
+ } else {
+ auto numShards = rewriter.create<arith::ConstantOp>(
----------------
fschlimb wrote:
ok
https://github.com/llvm/llvm-project/pull/129048
More information about the Mlir-commits
mailing list