[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)

Christian Ulmann llvmlistbot at llvm.org
Thu Feb 27 23:35:36 PST 2025


================
@@ -48,59 +76,202 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0) {
+    if (i > 0)
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-    }
   }
 
   return multiIndex;
 }
 
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                          ValueRange dimensions) {
 
-  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
-  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
     linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
 
   return linearIndex;
 }
 
+/// Replace GetShardingOp with related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return failure();
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp)
+      return failure();
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return success();
+  }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+///   (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    Location loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
+
+    // explicitly write values into tensor row by row
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0)
+        ++nSplits;
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes =
----------------
Dinistro wrote:

Nit: I would prefer that you spell out the type explicitly instead of using `auto` in cases like this. This is based on the LLVM Coding Standards: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

https://github.com/llvm/llvm-project/pull/129048


More information about the Mlir-commits mailing list