[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)
Tuomas Kärnä
llvmlistbot at llvm.org
Fri Feb 28 07:54:15 PST 2025
================
@@ -380,23 +380,159 @@ struct ConvertNeighborsLinearIndicesOp
[&](OpBuilder &builder, Location loc) {
SmallVector<Value> tmp = mIdx;
tmp[axes[0]] =
- rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
- .getResult();
+ rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
builder.create<scf::YieldOp>(
loc, multiToLinearIndex(loc, rewriter, tmp, dims));
});
rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
- return mlir::success();
+ return success();
}
};
-struct ConvertUpdateHaloOp
- : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
- using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+ using OpConversionPattern::OpConversionPattern;
- mlir::LogicalResult
- matchAndRewrite(mlir::mesh::UpdateHaloOp op,
- mlir::PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+ if (!sharding) {
+ return op->emitError()
+ << "Expected SharingOp as defining op for sharding"
+ << " but found " << adaptor.getSharding()[0].getDefiningOp();
+ }
+
+ // Compute the sharded shape by applying the sharding to the input shape.
+ // Without shardedDimsOffsets in the sharding, the shard shape is computed
----------------
tkarna wrote:
For clarity, consider rephrasing this as: "If shardedDimsOffsets is not defined in the sharding, ..."
https://github.com/llvm/llvm-project/pull/129048
More information about the Mlir-commits mailing list