[Mlir-commits] [mlir] [mlir][mesh, mpi] More on MeshToMPI (PR #129048)

Christian Ulmann llvmlistbot at llvm.org
Thu Feb 27 23:35:35 PST 2025


================
@@ -48,59 +76,202 @@ static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
 
   for (int i = n - 1; i >= 0; --i) {
     multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
-    if (i > 0) {
+    if (i > 0)
       linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
-    }
   }
 
   return multiIndex;
 }
 
-// Create operations converting a multi-dimensional index to a linear index
+/// Create operations converting a multi-dimensional index to a linear index.
 Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                          ValueRange dimensions) {
 
-  auto linearIndex = b.create<arith::ConstantIndexOp>(loc, 0).getResult();
-  auto stride = b.create<arith::ConstantIndexOp>(loc, 1).getResult();
+  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);
 
   for (int i = multiIndex.size() - 1; i >= 0; --i) {
-    auto off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
+    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
     linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
     stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
   }
 
   return linearIndex;
 }
 
+/// Replace GetShardingOp with the related/dependent ShardingOp.
+struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return failure();
+    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
+    if (!shardingOp)
+      return failure();
+
+    rewriter.replaceOp(op, shardingOp.getResult());
+    return success();
+  }
+};
+
+/// Convert a sharding op to a tuple of tensors of its components
+///   (SplitAxes, HaloSizes, ShardedDimsOffsets)
+/// as defined by the type converter.
+struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splitAxes = op.getSplitAxes().getAxes();
+    int64_t maxNAxes = 0;
+    for (auto axes : splitAxes) {
+      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());
+    }
+
+    // To hold the split axes, create an empty 2d tensor with shape
+    // {splitAxes.size(), max-size-of-split-groups}.
+    // Set trailing elements for smaller split-groups to -1.
+    Location loc = op.getLoc();
+    auto i16 = rewriter.getI16Type();
+    auto i64 = rewriter.getI64Type();
+    int64_t shape[] = {static_cast<int64_t>(splitAxes.size()), maxNAxes};
+    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
+    auto attr = IntegerAttr::get(i16, 0xffff);
+    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
+    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
+                       .getResult(0);
+
+    // Explicitly write the values into the tensor row by row.
+    int64_t strides[] = {1, 1};
+    int64_t nSplits = 0;
+    ValueRange empty = {};
+    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+      int64_t size = axes.size();
+      if (size > 0)
+        ++nSplits;
+      int64_t offs[] = {(int64_t)i, 0};
+      int64_t sizes[] = {1, size};
+      auto tensorType = RankedTensorType::get({size}, i16);
+      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
+      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
+      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
+          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
+    }
+
+    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
+    // Store the halo sizes in the tensor.
+    auto haloSizes =
+        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
+                         adaptor.getDynamicHaloSizes());
+    auto type = RankedTensorType::get({nSplits, 2}, i64);
+    Value resHaloSizes =
+        haloSizes.empty()
+            ? rewriter
+                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
+                                           i64)
+                  .getResult()
+            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
+                  .getResult();
+
+    // To hold sharded dims offsets, create a tensor with shape {nSplits,
+    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
+    // elements for smaller split-groups to ShapedType::kDynamic. Computing
+    // the max size of the split groups requires collectiveProcessGroupSize
+    // (which in turn requires the MeshOp).
+    Value resOffsets;
+    if (adaptor.getStaticShardedDimsOffsets().empty()) {
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{0, 0}, i64);
+    } else {
+      SymbolTableCollection symbolTableCollection;
+      auto meshOp = getMesh(op, symbolTableCollection);
+      int64_t maxSplitSize = 0;
+      for (auto axes : splitAxes) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic);
+        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
+      }
+      assert(maxSplitSize);
+      ++maxSplitSize; // add one for the total size
+
+      resOffsets = rewriter.create<tensor::EmptyOp>(
+          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
+      Value fill = rewriter.create<arith::ConstantOp>(
+          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
+      resOffsets =
+          rewriter.create<linalg::FillOp>(loc, fill, resOffsets).getResult(0);
+      auto offsets =
+          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
+                           adaptor.getDynamicShardedDimsOffsets());
+      int64_t curr = 0;
+      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
+        int64_t splitSize =
+            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
+        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
+        ++splitSize; // add one for the total size
+        ArrayRef<Value> values(&offsets[curr], splitSize);
+        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
+        int64_t offs[] = {(int64_t)i, 0};
----------------
Dinistro wrote:

```suggestion
        int64_t offs[] = {static_cast<int64_t>(i), 0};
```
Nit: Don't use C-style casts.
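
For readers following the nit, here is a minimal standalone illustration (a hypothetical example, not part of the patch) of why `static_cast` is preferred: a C-style cast silently falls back to `const_cast`/`reinterpret_cast` semantics, while `static_cast` rejects those conversions at compile time.

```cpp
#include <cstddef>
#include <cstdint>

void casts(const int64_t *p, std::size_t i) {
  int64_t j = (int64_t)i;              // C-style cast: works, but the same
  int64_t *q = (int64_t *)p;           // syntax also silently drops const.
  int64_t k = static_cast<int64_t>(i); // Intent is explicit and checked.
  // int64_t *r = static_cast<int64_t *>(p); // Compile error: casts away
  //                                         // constness.
  (void)j; (void)q; (void)k;
}
```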

https://github.com/llvm/llvm-project/pull/129048


More information about the Mlir-commits mailing list