[llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)
Lei Zhang via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 13 12:01:59 PST 2024
================
@@ -64,9 +70,83 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
}
+};
+
+struct AllScatterOpLowering
+ : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
+
+ LogicalResult matchAndRewrite(AllScatterOp op,
+ PatternRewriter &rewriter) const override {
+ MeshOp mesh = getMesh(op, symbolTableCollection);
+ if (!mesh) {
+ return failure();
+ }
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ builder.setInsertionPointAfter(op.getOperation());
+
+ Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+ Operation::result_range processInGroupMultiIndex =
+ builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+ .getResults();
+
+ Operation::result_range processGroupShape =
+ builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+ .getResult();
+ Value processGroupSize =
+ createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+
+ int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+ Value operandScatterAxisSize =
+ builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
+ Value operandScatterAxisSizeModProcessGroupSize =
+ builder.create<arith::RemUIOp>(operandScatterAxisSize,
+ processGroupSize);
+ Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+ arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+ zero);
+ builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
+ "Scattering a tensor with axis size that is "
+ "not exactly divisible by the "
+ "mesh process group size is not supported.");
+ Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
+ operandScatterAxisSize, processGroupSize);
+ OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+ llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+ llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+
+ // extract slice
----------------
antiagainst wrote:
Would be nice to flesh out the comment here a bit, explaining the indexing logic below; that would make it easier for readers to follow.
https://github.com/llvm/llvm-project/pull/81218
More information about the llvm-commits mailing list