[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)

Frank Schlimbach llvmlistbot at llvm.org
Fri Feb 28 01:51:32 PST 2025


================
@@ -419,14 +744,95 @@ struct ConvertMeshToMPIPass
 
   /// Run the dialect converter on the module.
   void runOnOperation() override {
-    auto *ctx = &getContext();
-    mlir::RewritePatternSet patterns(ctx);
+    uint64_t worldRank = -1;
+    // Try to get DLTI attribute for MPI:comm_world_rank
+    // If found, set worldRank to the value of the attribute.
+    {
+      auto dltiAttr =
+          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
+      if (succeeded(dltiAttr)) {
+        if (!isa<IntegerAttr>(dltiAttr.value())) {
+          getOperation()->emitError()
+              << "Expected an integer attribute for MPI:comm_world_rank";
+          return signalPassFailure();
+        }
+        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
+      }
+    }
+
+    auto *ctxt = &getContext();
+    RewritePatternSet patterns(ctxt);
+    ConversionTarget target(getContext());
+
+    // Define a type converter to convert mesh::ShardingType,
+    // mostly for use in return operations.
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+
+    // convert mesh::ShardingType to a tuple of RankedTensorTypes
+    typeConverter.addConversion(
+        [](ShardingType type,
+           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
+          auto i16 = IntegerType::get(type.getContext(), 16);
+          auto i64 = IntegerType::get(type.getContext(), 64);
+          std::array<int64_t, 2> shp{ShapedType::kDynamic,
----------------
fschlimb wrote:

ok

https://github.com/llvm/llvm-project/pull/129048


More information about the Mlir-commits mailing list