[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)

Frank Schlimbach llvmlistbot at llvm.org
Fri Feb 28 01:51:31 PST 2025


================
@@ -206,23 +377,159 @@ struct ConvertNeighborsLinearIndicesOp
         [&](OpBuilder &builder, Location loc) {
           SmallVector<Value> tmp = mIdx;
           tmp[axes[0]] =
-              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one)
-                  .getResult();
+              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
           builder.create<scf::YieldOp>(
               loc, multiToLinearIndex(loc, rewriter, tmp, dims));
         });
     rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
-    return mlir::success();
+    return success();
   }
 };
 
-struct ConvertUpdateHaloOp
-    : public mlir::OpRewritePattern<mlir::mesh::UpdateHaloOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
+    if (!sharding) {
+      return op->emitError()
+             << "Expected ShardingOp as defining op for sharding"
+             << " but found " << adaptor.getSharding()[0].getDefiningOp();
+    }
+
+    // Compute the sharded shape by applying the sharding to the input shape.
+    // Without shardedDimsOffsets in the sharding, the shard shape is computed
+    // by dividing the dimension size by the number of shards in that dimension
+    // (which is given by the size of the mesh axes provided in split-axes).
+    // Odd elements get distributed to trailing shards.
+    // If a shardedDimsOffsets is provided, the shard shape is computed by
+    // subtracting the offset of the current shard from the offset of the next
+    // shard.
+
+    Location loc = op.getLoc();
+    Type index = rewriter.getIndexType();
+
+    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
+    // The operands in the adaptor are a vector<ValueRange>. For dims and device
+    // we have a 1:1 conversion.
+    // For simpler access fill a vector with the dynamic dims.
+    SmallVector<Value> dynDims, dynDevice;
+    for (auto dim : adaptor.getDimsDynamic()) {
+      // type conversion should be 1:1 for ints
+      assert(dim.size() == 1);
+      dynDims.emplace_back(dim[0]);
+    }
+    // same for device
+    for (auto device : adaptor.getDeviceDynamic()) {
+      assert(device.size() == 1);
+      dynDevice.emplace_back(device[0]);
+    }
 
-  mlir::LogicalResult
-  matchAndRewrite(mlir::mesh::UpdateHaloOp op,
-                  mlir::PatternRewriter &rewriter) const override {
+    // To keep the code simple, convert dims/device to values when they are
+    // attributes. Count on canonicalization to fold static values.
+    auto shape = getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
+    auto multiIdx =
+        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);
+
+    // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
+    SymbolTableCollection symbolTableCollection;
+    auto meshOp = getMesh(sharding, symbolTableCollection);
+    // For now we only support static mesh shapes
+    if (ShapedType::isDynamicShape(meshOp.getShape()))
+      return failure();
+
+    auto splitAxes = sharding.getSplitAxes().getAxes();
+    // shardedDimsOffsets are optional and might be Values (not attributes).
+    // Also, the shardId might be dynamic which means the position in the
+    // shardedDimsOffsets is not statically known. Create a tensor of the
+    // shardedDimsOffsets and later extract the offsets for computing the
+    // local shard-size.
+    Value shardedDimsOffs;
+    {
+      auto tmp = getMixedAsValues(
+          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
+          sharding.getDynamicShardedDimsOffsets(), index);
+      if (!tmp.empty())
+        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
+            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
+    }
+
+    // With static mesh shape the sizes of the split axes are known.
+    // Hence the start/pos for each split axis in shardedDimsOffsets can be
+    // computed statically.
+    int64_t pos = 0;
+    SmallVector<Value> shardShape;
+    Value zero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
+    Value one =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));
+
+    // Iterate over the dimensions of the tensor shape, get their split Axes,
+    // and compute the sharded shape.
+    for (auto [i, dim] : llvm::enumerate(shape)) {
+      // Trailing dimensions might not be annotated.
+      if (i < splitAxes.size() && !splitAxes[i].empty()) {
+        auto axes = splitAxes[i];
+        // The current dimension might not be sharded.
+        // Create a value from the static position in shardDimsOffsets.
+        Value posVal =
+            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
+        // Get the index of the local shard in the mesh axis.
+        Value idx = multiIdx[axes[0]];
+        auto _numShards =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        if (shardedDimsOffs) {
+          // If sharded dims offsets are provided, use them to compute the
+          // sharded shape.
+          if (axes.size() > 1) {
+            return op->emitError() << "Only single axis sharding is "
+                                   << "supported for each dimension.";
+          }
+          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
+          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
+          Value off =
+              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
+          Value nextOff =
+              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
+          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
+          shardShape.emplace_back(sz);
+        } else {
+          auto numShards = rewriter.create<arith::ConstantOp>(
----------------
fschlimb wrote:

ok

https://github.com/llvm/llvm-project/pull/129048


More information about the Mlir-commits mailing list