[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)

Tuomas Kärnä llvmlistbot at llvm.org
Fri Feb 28 07:54:15 PST 2025


================
@@ -380,23 +380,159 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertUpdateHaloOp
-    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+    if (!sharding) {
+      return op->emitError()
+             << "Expected SharingOp as defining op for sharding"
+             << " but found " << adaptor.getSharding()[0].getDefiningOp();
+    }
+
+    // Compute the sharded shape by applying the sharding to the input shape.
+    // Without shardedDimsOffsets in the sharding, the shard shape is computed
----------------
tkarna wrote:

For clarity, consider rewording the start of this comment sentence as: "If shardedDimsOffsets is not defined in the sharding ..."

https://github.com/llvm/llvm-project/pull/129048


More information about the Mlir-commits mailing list