[llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)

Boian Petkantchin via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 13 14:37:58 PST 2024


================
@@ -64,9 +70,83 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
     return success();
   }
+};
+
+struct AllScatterOpLowering
+    : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
+
+  LogicalResult matchAndRewrite(AllScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    MeshOp mesh = getMesh(op, symbolTableCollection);
+    if (!mesh) {
+      return failure();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+
+    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+    Operation::result_range processInGroupMultiIndex =
+        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResults();
+
+    Operation::result_range processGroupShape =
+        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResult();
+    Value processGroupSize =
+        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+
+    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+    Value operandScatterAxisSize =
+        builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
+    Value operandScatterAxisSizeModProcessGroupSize =
+        builder.create<arith::RemUIOp>(operandScatterAxisSize,
+                                       processGroupSize);
+    Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+        arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+        zero);
+    builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
+                                 "Scattering a tensor with axis size that is "
+                                 "not exactly divisible by the "
+                                 "mesh process group size is not supported.");
+    Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
+        operandScatterAxisSize, processGroupSize);
+    OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+        llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+        llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+
+    // extract slice
----------------
sogartar wrote:

I added some documentation at the beginning of the function describing the algorithm.

https://github.com/llvm/llvm-project/pull/81218


More information about the llvm-commits mailing list