[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Feb 28 01:51:32 PST 2025
================
@@ -419,14 +744,95 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
- auto *ctx = &getContext();
- mlir::RewritePatternSet patterns(ctx);
+ uint64_t worldRank = -1;
+ // Try to get DLTI attribute for MPI:comm_world_rank
+ // If found, set worldRank to the value of the attribute.
+ {
+ auto dltiAttr =
+ dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+ if (succeeded(dltiAttr)) {
+ if (!isa<IntegerAttr>(dltiAttr.value())) {
+ getOperation()->emitError()
+ << "Expected an integer attribute for MPI:comm_world_rank";
+ return signalPassFailure();
+ }
+ worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+ }
+ }
+
+ auto *ctxt = &getContext();
+ RewritePatternSet patterns(ctxt);
+ ConversionTarget target(getContext());
+
+ // Define a type converter to convert mesh::ShardingType,
+ // mostly for use in return operations.
+ TypeConverter typeConverter;
+ typeConverter.addConversion([](Type type) { return type; });
+
+ // convert mesh::ShardingType to a tuple of RankedTensorTypes
+ typeConverter.addConversion(
+ [](ShardingType type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+ auto i16 = IntegerType::get(type.getContext(), 16);
+ auto i64 = IntegerType::get(type.getContext(), 64);
+ std::array<int64_t, 2> shp{ShapedType::kDynamic,
----------------
fschlimb wrote:
ok
https://github.com/llvm/llvm-project/pull/129048
More information about the Mlir-commits mailing list