[llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 8 18:59:09 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Boian Petkantchin (sogartar)
<details>
<summary>Changes</summary>
This op is the inverse of all-gather. It is useful to have an explicit concise representation instead of having a blob of slicing logic.
Add lowering for the op that slices from the tensor based on the in-group process index.
Make resharding generate an all-scatter instead of inserting the slicing logic directly.
---
Patch is 55.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81218.diff
26 Files Affected:
- (modified) mlir/include/mlir/Dialect/Affine/Utils.h (+7)
- (modified) mlir/include/mlir/Dialect/Arith/Utils/Utils.h (+17)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+6)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td (+63)
- (modified) mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h (+18-1)
- (modified) mlir/include/mlir/IR/Builders.h (+3)
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+27)
- (modified) mlir/lib/Dialect/Arith/Utils/Utils.cpp (+19)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+61-3)
- (modified) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+2)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (+5-11)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+8-76)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (+125-11)
- (added) mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h (+35)
- (modified) mlir/lib/IR/Builders.cpp (+18)
- (added) mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir (+74)
- (modified) mlir/test/Dialect/Mesh/canonicalization.mlir (+13)
- (modified) mlir/test/Dialect/Mesh/invalid.mlir (+52)
- (modified) mlir/test/Dialect/Mesh/ops.mlir (+24)
- (modified) mlir/test/Dialect/Mesh/resharding-spmdization.mlir (+6-25)
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+6-8)
- (modified) mlir/test/lib/Dialect/Mesh/CMakeLists.txt (+1-1)
- (added) mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp (+79)
- (removed) mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp (-54)
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2-2)
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
class DominanceInfo;
class Operation;
class PostDominanceInfo;
+class ImplicitLocOpBuilder;
namespace func {
class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
+// Generate IR that extracts the linear index from a multi-index according to
+// a shape.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> shape,
+ ImplicitLocOpBuilder &builder);
/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
namespace mlir {
@@ -81,6 +82,22 @@ struct ArithBuilder {
OpBuilder &b;
Location loc;
};
+
+namespace arith {
+
+// Build the product of a sequence.
+// If values = (v0, v1, ..., vn) then the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The version without `resultType` must contain at least one element in values.
+// Then the result will have the same type as the elements in `values`.
+// If `values` is empty in the version with `resultType` returns 1 with type
+// `resultType`.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+ Type resultType);
+} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+ return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+ mesh.getShape());
+}
+
// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
let builders = [
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+ Pure,
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultRank
+ ]> {
+ let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+ let description = [{
+ Scatter along the `scatter_axis` tensor axis.
+ This operation can be thought of as the inverse of all-gather.
+ Technically, it is not required that all processes have the same input tensor.
+ Each process will slice a piece of its local tensor based on its in-group device index.
+ The operation does not communicate data between devices.
+
+ Example:
+ ```mlir
+ mesh.mesh @mesh0(shape = 2x2)
+ ...
+ %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+ : tensor<2x4xi8> -> tensor<2x2xi8>
+ ```
+ Input:
+ ```
+ +-------------+
+ | 1 2 5 6 | <- devices (0, 0) and (0, 1)
+ | 3 4 7 8 |
+ +-------------+
+ | 9 10 13 14 | <- devices (1, 0) and (1, 1)
+ | 11 12 15 16 |
+ +-------------+
+ ```
+ Result:
+ ```
+ gather tensor
+ axis 1
+ ------------>
+ +-------+-------+
+ device (0, 0) -> | 1 2 | 5 6 | <- device (0, 1)
+ | 3 4 | 7 8 |
+ +-------+-------+
+ device (1, 0) -> | 9 10 | 13 14 | <- device (1, 1)
+ | 11 12 | 15 16 |
+ +-------+-------+
+ ```
+ }];
+ let arguments = !con(commonArgs, (ins
+ AnyNon0RankedTensor:$input,
+ IndexAttr:$scatter_axis
+ ));
+ let results = (outs
+ AnyNon0RankedTensor:$result
+ );
+ let assemblyFormat = [{
+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+ attr-dict `:` type($input) `->` type($result)
+ }];
+ let hasCanonicalizer = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+ OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+ ];
+}
+
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Pure,
SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
+class ImplicitLocOpBuilder;
namespace mesh {
void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+void allScatterOpLoweringPopulatePatterns(
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+void populateAllOpLoweringPatterns(
+ RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+ ImplicitLocOpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
// ranked tensor of them. Returns null attribute otherwise.
TypedAttr getZeroAttr(Type type);
+ // Returns a 1-valued attribute of the given `type`.
+ // Type constraints are the same as `getZeroAttr`.
+ TypedAttr getOneAttr(Type type);
// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
results.push_back(residual);
return results;
}
+
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> shape,
+ ImplicitLocOpBuilder &builder) {
+ assert(multiIndex.size() == shape.size());
+ SmallVector<AffineExpr> shapeAffine;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+ }
+
+ SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(stridesAffine.size());
+ llvm::transform(stridesAffine, std::back_inserter(strides),
+ [&builder, &shape](AffineExpr strideExpr) {
+ return affine::makeComposedFoldedAffineApply(
+ builder, builder.getLoc(), strideExpr, shape);
+ });
+
+ auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
+ OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+ return affine::makeComposedFoldedAffineApply(
+ builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
using namespace mlir;
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
}
+
+namespace mlir::arith {
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+ return createProduct(builder, loc, values, values.front().getType());
+}
+
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+ Type resultType) {
+ Value one = builder.create<ConstantOp>(loc, resultType,
+ builder.getOneAttr(resultType));
+ ArithBuilder arithBuilder(builder, loc);
+ return std::accumulate(
+ values.begin(), values.end(), one,
+ [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh) {
+ build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ MeshOp mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(),
- MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+ SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+ odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
+ assert(!axes.empty());
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
return success();
}
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ int64_t scatterAxis) {
+ RankedTensorType operandRankedTensorType =
+ operandType.cast<RankedTensorType>();
+ DimensionSize operandScatterAxisSize =
+ operandRankedTensorType.getShape()[scatterAxis];
+ SmallVector<int64_t> resultShape =
+ llvm::to_vector(operandRankedTensorType.getShape());
+
+ resultShape[scatterAxis] =
+ operandScatterAxisSize /
+ DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+ return operandRankedTensorType.clone(resultShape);
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyScatterOperandAndResultShape(
+ getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+ mesh.value().getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+ int64_t scatterAxis) {
+ Type resultType =
+ scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+ build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+ scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Type resultType, Value input, StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+ build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+ APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRAffineDialect
+ MLIRAffineUtils
MLIRArithDialect
+ MLIRArithUtils
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include <iterator>
#include <numeric>
#include <utility>
@@ -56,13 +56,10 @@ namespace {
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
- template <typename... OpRewritePatternArgs>
- MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
- : OpRewritePattern(
- std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
- symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+ : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
LogicalResult matchAndRewrite(MeshShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
return success();
}
-
-private:
- SymbolTableCollection &symbolTableCollection;
};
} // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
- int64_t splitTensorAxis,
- int64_t splitCount) {
- SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
- targetShape[splitTensorAxis] =
- shardDimension(targetShape[splitTensorAxis], splitCount);
- return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- MLIRContext *ctx = builder.getContext();
- builder.setInsertionPointAfterValue(sourceShard);
-
- Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
- Value processIndexAlongAxis =
+ TypedValue<ShapedType> targetShard =
builder
- .create<ProcessMultiIndexOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
- .getResult()[0];
-
+ .create<AllScatterOp>(sourceShard, mesh,
+ ArrayRef<MeshAxis>(splitMeshAxis),
+ splitTensorAxis)
+ .ge...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81218
More information about the llvm-commits
mailing list