[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)
Christian Ulmann
llvmlistbot at llvm.org
Thu Feb 27 23:35:36 PST 2025
================
@@ -48,59 +76,202 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
for (int i = n - 1; i >= 0; --i) {
multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
- if (i > 0) {
+ if (i > 0)
linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
- }
}
return multiIndex;
}
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
ValueRange dimensions) {
- auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
- auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+ Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
for (int i = multiIndex.size() - 1; i >= 0; --i) {
- auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+ Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
}
return linearIndex;
}
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+ if (!shardOp)
+ return failure();
+ auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+ if (!shardingOp)
+ return failure();
+
+ rewriter.replaceOp(op, shardingOp.getResult());
+ return success();
+ }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto splitAxes = op.getSplitAxes().getAxes();
+ int64_t maxNAxes = 0;
+ for (auto axes : splitAxes) {
+ maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+ }
----------------
Dinistro wrote:
```suggestion
for (auto axes : splitAxes)
maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
```
https://github.com/llvm/llvm-project/pull/129048
More information about the Mlir-commits mailing list