[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)
Boian Petkantchin
llvmlistbot at llvm.org
Thu Feb 8 19:10:23 PST 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/81218
>From 5d1e4cdaf9b5722d8322adaf36b4dce1962381e3 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 7 Feb 2024 08:05:53 -0800
Subject: [PATCH 1/2] [mlir][mesh] Add all-scatter operation
This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.
Add lowering for the op that slices from the tensor based on the in-group
process index.
Make resharding generate an all-scatter instead of inserting the slicing logic
directly.
---
mlir/include/mlir/Dialect/Affine/Utils.h | 7 +
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 17 +++
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 6 +
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 63 ++++++++
.../mlir/Dialect/Mesh/Transforms/Transforms.h | 19 ++-
mlir/include/mlir/IR/Builders.h | 3 +
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 27 ++++
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 19 +++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 64 ++++++++-
.../Dialect/Mesh/Transforms/CMakeLists.txt | 2 +
.../Mesh/Transforms/Simplifications.cpp | 16 +--
.../Dialect/Mesh/Transforms/Spmdization.cpp | 84 ++---------
.../Dialect/Mesh/Transforms/Transforms.cpp | 136 ++++++++++++++++--
.../Mesh/Transforms/TransformsDetail.h | 35 +++++
mlir/lib/IR/Builders.cpp | 18 +++
.../Dialect/Mesh/all-scatter-op-lowering.mlir | 74 ++++++++++
mlir/test/Dialect/Mesh/canonicalization.mlir | 13 ++
mlir/test/Dialect/Mesh/invalid.mlir | 52 +++++++
mlir/test/Dialect/Mesh/ops.mlir | 24 ++++
.../Dialect/Mesh/resharding-spmdization.mlir | 31 +---
mlir/test/Dialect/Mesh/spmdization.mlir | 14 +-
mlir/test/lib/Dialect/Mesh/CMakeLists.txt | 2 +-
mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp | 79 ++++++++++
.../Mesh/TestProcessMultiIndexOpLowering.cpp | 54 -------
mlir/tools/mlir-opt/mlir-opt.cpp | 4 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 2 +
26 files changed, 673 insertions(+), 192 deletions(-)
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
create mode 100644 mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
create mode 100644 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
delete mode 100644 mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
class DominanceInfo;
class Operation;
class PostDominanceInfo;
+class ImplicitLocOpBuilder;
namespace func {
class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
+// Generate IR that computes the linear index of `multiIndex` within an index
+// space of the given `shape`. `multiIndex` and `shape` must have the same
+// length. The result is an OpFoldResult, so it may come back as an attribute
+// rather than a Value.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> shape,
+ ImplicitLocOpBuilder &builder);
/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
namespace mlir {
@@ -81,6 +82,22 @@ struct ArithBuilder {
OpBuilder &b;
Location loc;
};
+
+namespace arith {
+
+// Build the product of a sequence of values.
+// If values = (v0, v1, ..., vn) then the returned
+// value is v0 * v1 * ... * vn.
+// All values must have the same type.
+//
+// The overload without `resultType` requires `values` to contain at least one
+// element; the result then has the same type as the elements of `values`.
+// The overload with `resultType` returns the constant 1 of type `resultType`
+// when `values` is empty.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+ Type resultType);
+} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
+// Convenience overload of `collectiveProcessGroupSize` that reads the mesh
+// shape from `mesh` instead of taking it explicitly.
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+  auto meshShape = mesh.getShape();
+  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+                                    meshShape);
+}
+
// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
let builders = [
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let description = [{
+    Scatter along the `scatter_axis` tensor axis.
+    This operation can be thought of as the inverse of all-gather.
+    Technically, it is not required that all processes have the same input tensor.
+    Each process will slice a piece of its local tensor based on its in-group device index.
+    The operation does not communicate data between devices.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh0(shape = 2x2)
+    ...
+    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+      : tensor<2x4xi8> -> tensor<2x2xi8>
+    ```
+    Input:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Result:
+    ```
+                     scatter tensor
+                     axis 1
+                     ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+  ];
+}
+
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Pure,
SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
+class ImplicitLocOpBuilder;
namespace mesh {
void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+// Lowering of `mesh.all_scatter` into a slice of the local tensor.
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+// Convenience entry points that populate/register every mesh op lowering
+// declared in this header.
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+// Build IR computing the size of the process group spanned by the mesh `axes`
+// of `mesh`, as an index value.
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
// ranked tensor of them. Returns null attribute otherwise.
TypedAttr getZeroAttr(Type type);
+ // Returns a 1-valued attribute of the given `type`: a scalar 1 for float,
+ // index and integer types, or a dense attribute filled with 1 for ranked
+ // tensor and vector types of those. Returns a null attribute for any other
+ // type, mirroring the type constraints of `getZeroAttr`.
+ TypedAttr getOneAttr(Type type);
// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
results.push_back(residual);
return results;
}
+
+// Linearize `multiIndex` against `shape`: the strides are derived symbolically
+// from the shape, materialized against the concrete shape values, and then
+// combined with the multi-index in a single composed affine apply.
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                   ArrayRef<OpFoldResult> shape,
+                                   ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  MLIRContext *ctx = builder.getContext();
+  Location loc = builder.getLoc();
+
+  // One affine symbol per shape dimension: s0, s1, ..., s(n-1).
+  SmallVector<AffineExpr> symbolicShape;
+  for (unsigned dim = 0, rank = shape.size(); dim < rank; ++dim)
+    symbolicShape.push_back(getAffineSymbolExpr(dim, ctx));
+
+  // Evaluate each symbolic stride against the concrete shape operands.
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(shape.size());
+  for (AffineExpr strideExpr : computeStrides(symbolicShape))
+    strides.push_back(
+        affine::makeComposedFoldedAffineApply(builder, loc, strideExpr, shape));
+
+  auto &&[linearIndexExpr, applyOperands] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
+                                               applyOperands);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
using namespace mlir;
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
}
+
+namespace mlir::arith {
+
+// Product of `values` with the element type of `values`. `values` must not be
+// empty, since its first element supplies the result type.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  assert(!values.empty() &&
+         "createProduct without a result type needs at least one value");
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+// Product of `values`, seeded with the constant 1 of `resultType`; an empty
+// `values` therefore yields that constant.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  TypedAttr oneAttr = builder.getOneAttr(resultType);
+  assert(oneAttr && "createProduct: unsupported result type");
+  Value one = builder.create<ConstantOp>(loc, resultType, oneAttr);
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+// Builds a mesh_shape op over all axes of `mesh` by delegating to the
+// overload below with an empty axis list.
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh) {
+ build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+// Builds a mesh_shape op for `axes` of `mesh` with one index-typed result per
+// requested axis. An empty `axes` yields `mesh.getRank()` results (one per
+// mesh dimension) and an empty axes attribute.
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ MeshOp mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(),
- MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+ SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+ odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
+ assert(!axes.empty());
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
return success();
}
+// Result type of an all_scatter: the operand type with its `scatterAxis`
+// dimension divided by the size of the process group spanned by `meshAxes`
+// (the division is done on `DimensionSize`, which presumably keeps dynamic
+// sizes dynamic). All other dimensions are unchanged.
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+                                          ArrayRef<MeshAxis> meshAxes,
+                                          int64_t scatterAxis) {
+  auto tensorType = llvm::cast<RankedTensorType>(operandType);
+  ArrayRef<int64_t> operandShape = tensorType.getShape();
+  SmallVector<int64_t> shape(operandShape.begin(), operandShape.end());
+
+  DimensionSize groupSize(collectiveProcessGroupSize(meshAxes, mesh));
+  shape[scatterAxis] = DimensionSize(operandShape[scatterAxis]) / groupSize;
+  return tensorType.clone(shape);
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+// Resolves the referenced mesh (checking `mesh_axes` against it), then checks
+// the operand/result shapes with `verifyScatterOperandAndResultShape` using
+// the mesh shape.
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ return verifyScatterOperandAndResultShape(
+ getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+ mesh.value().getShape());
+}
+
+// Registers `EmptyMeshAxesCanonicalizationPattern` for this op (presumably it
+// folds away all_scatter ops whose `mesh_axes` list is empty).
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+// Builds an all_scatter whose result type is inferred from `input` by
+// `scatterResultType`.
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+ int64_t scatterAxis) {
+ Type resultType =
+ scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+ build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
+ scatterAxis);
+}
+
+// Builds an all_scatter from an explicit result type and a mesh symbol name.
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Type resultType, Value input, StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+ // Forwards to the builder taking (resultType, mesh, meshAxes, input, axis) --
+ // presumably the ODS-generated one; `scatter_axis` is passed as a 64-bit
+ // APInt (sizeof(int64_t) * CHAR_BIT bits).
+ build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+ APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRAffineDialect
+ MLIRAffineUtils
MLIRArithDialect
+ MLIRArithUtils
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include <iterator>
#include <numeric>
#include <utility>
@@ -56,13 +56,10 @@ namespace {
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
- template <typename... OpRewritePatternArgs>
- MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
- : OpRewritePattern(
- std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
- symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+ : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
LogicalResult matchAndRewrite(MeshShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
return success();
}
-
-private:
- SymbolTableCollection &symbolTableCollection;
};
} // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
- int64_t splitTensorAxis,
- int64_t splitCount) {
- SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
- targetShape[splitTensorAxis] =
- shardDimension(targetShape[splitTensorAxis], splitCount);
- return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
+//
+// The split is expressed as a single `mesh.all_scatter` of `sourceShard`
+// along `splitTensorAxis` over the mesh axis `splitMeshAxis`.
+// NOTE(review): the previous implementation moved the insertion point after
+// `sourceShard`; this one uses `builder`'s current insertion point -- confirm
+// callers position the builder appropriately.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- MLIRContext *ctx = builder.getContext();
- builder.setInsertionPointAfterValue(sourceShard);
-
- Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
- Value processIndexAlongAxis =
+ TypedValue<ShapedType> targetShard =
builder
- .create<ProcessMultiIndexOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
- .getResult()[0];
-
+ .create<AllScatterOp>(sourceShard, mesh,
+ ArrayRef<MeshAxis>(splitMeshAxis),
+ splitTensorAxis)
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
- ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
- ShapedType targetShape = targetShapeInSplitLastAxis(
- sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
-
- Value meshAxisSize =
- builder
- .create<MeshShapeOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
- .getResult()[0];
-
- Value sourceAxisSize =
- builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
- Value sourceAxisSizeModMeshAxisSize =
- builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
- Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
- builder.create<cf::AssertOp>(
- isTargetShapeExactlyDivisible,
- "Sharding a tensor with axis size that is not exactly divisible by the "
- "mesh axis size is not supported.");
- Value targetAxisSize =
- builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
- Value axisOffset =
- builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
- SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
- staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
- DenseI64ArrayAttr staticOffsetsAttr =
- DenseI64ArrayAttr::get(ctx, staticOffsets);
- SmallVector<Value> dynamicOffsets(1, axisOffset);
-
- DenseI64ArrayAttr staticSizesAttr =
- DenseI64ArrayAttr::get(ctx, targetShape.getShape());
- SmallVector<Value> dynamicSizes;
- for (int64_t i = 0; i < targetShape.getRank(); ++i) {
- if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
- if (i == splitTensorAxis) {
- dynamicSizes.push_back(targetAxisSize);
- } else {
- Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
- dynamicSizes.push_back(dimSize);
- }
- }
- }
-
- DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
- ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
- TypedValue<RankedTensorType> targetShard =
- builder
- .create<tensor::ExtractSliceOp>(
- targetShape, sourceShard, dynamicOffsets, dynamicSizes,
- SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
- staticStridesAttr)
- .getResult();
- return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+ builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+ return {targetShard, targetSharding};
}
// Detect if the resharding is of type e.g.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 03b1d9b3498028..95739c18ce0f36 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -7,12 +7,21 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "TransformsDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
@@ -26,18 +35,14 @@ namespace {
/// Lower `mesh.process_multi_index` into expression using
/// `mesh.process_linear_index` and `mesh.mesh_shape`.
-struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
- template <typename... OpRewritePatternArgs>
- ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
- : OpRewritePattern(
- std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
- symbolTableCollection(symbolTableCollection) {}
+struct ProcessMultiIndexOpLowering
+ : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
- MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
- op.getOperation(), op.getMeshAttr());
+ MeshOp mesh = getMesh(op, symbolTableCollection);
if (!mesh) {
return failure();
}
@@ -64,9 +69,83 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
}
+};
+
+struct AllScatterOpLowering
+    : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
+
+  // Lowers `mesh.all_scatter` to a `tensor.extract_slice` of the local operand.
+  // Along the scatter axis each process takes a chunk of size
+  // `operandSize / processGroupSize` at offset
+  // `processInGroupLinearIndex * chunkSize`; every other axis is kept whole.
+  // A runtime `cf.assert` checks that the scatter axis size is divisible by
+  // the process group size.
+  LogicalResult matchAndRewrite(AllScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    MeshOp mesh = getMesh(op, symbolTableCollection);
+    if (!mesh) {
+      return failure();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+
+    // With no mesh axes the process group is just this process, so nothing
+    // is sliced off. Handle it up front: the `MeshShapeOp` builder taking a
+    // mesh symbol name (used below) asserts that its axis list is non-empty.
+    if (op.getMeshAxes().empty()) {
+      Value unsliced = builder.create<tensor::CastOp>(op.getResult().getType(),
+                                                      op.getOperand());
+      rewriter.replaceOp(op, unsliced);
+      return success();
+    }
+
+    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+    Operation::result_range processInGroupMultiIndex =
+        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResults();
+
+    Operation::result_range processGroupShape =
+        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResult();
+    Value processGroupSize =
+        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+
+    // Runtime check that the scatter axis splits evenly across the group.
+    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+    Value operandScatterAxisSize =
+        builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
+    Value operandScatterAxisSizeModProcessGroupSize =
+        builder.create<arith::RemUIOp>(operandScatterAxisSize,
+                                       processGroupSize);
+    Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+        arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+        zero);
+    builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
+                                 "Scattering a tensor with axis size that is "
+                                 "not exactly divisible by the "
+                                 "mesh process group size is not supported.");
+    Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
+        operandScatterAxisSize, processGroupSize);
+    // Linear index of this process within its process group.
+    OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+        llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+        llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+
+    // Slice sizes: this process's chunk on the scatter axis, the full operand
+    // extent on every other axis.
+    RankedTensorType operandType =
+        llvm::cast<RankedTensorType>(op.getOperand().getType());
+    SmallVector<OpFoldResult> sizes;
+    for (int64_t i = 0; i < operandType.getRank(); ++i) {
+      if (i == scatterAxis) {
+        sizes.emplace_back(resultScatterAxisSize);
+      } else {
+        Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
+        sizes.emplace_back(dimSize);
+      }
+    }
+    SmallVector<OpFoldResult> offsets(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
+    offsets[scatterAxis] =
+        ArithBuilder(builder, builder.getLoc())
+            .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
+                                                 processInGroupLinearIndex),
+                 resultScatterAxisSize);
+    SmallVector<OpFoldResult> strides(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
+    Value slice = builder.create<tensor::ExtractSliceOp>(
+        op.getOperand(), offsets, sizes, strides);
+    // The slice is built from Value sizes; cast it back to the op's declared
+    // result type.
+    Value newResult =
+        builder.create<tensor::CastOp>(op.getResult().getType(), slice);
+    rewriter.replaceOp(op, newResult);
-private:
- SymbolTableCollection &symbolTableCollection;
+
+    return success();
+  }
};
} // namespace
@@ -81,4 +161,38 @@ void processMultiIndexOpLoweringRegisterDialects(DialectRegistry ®istry) {
registry.insert<affine::AffineDialect, mesh::MeshDialect>();
}
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  patterns.add<AllScatterOpLowering>(symbolTableCollection,
+                                     patterns.getContext());
+}
+
+// Registers the dialects whose ops `AllScatterOpLowering` may create.
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, arith::ArithDialect,
+                  cf::ControlFlowDialect, mesh::MeshDialect,
+                  tensor::TensorDialect>();
+}
+
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  processMultiIndexOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+  allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+}
+
+void registerAllOpLoweringDialects(DialectRegistry &registry) {
+  processMultiIndexOpLoweringRegisterDialects(registry);
+  allScatterOpLoweringRegisterDialects(registry);
+}
+
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder) {
+  // A collective over no mesh axes involves only the current process
+  // (presumably matching `collectiveProcessGroupSize` over an empty axis
+  // list). `MeshShapeOp` would treat empty `axes` as "all mesh axes", so skip
+  // it and let `createProduct` return the constant 1.
+  SmallVector<Value> groupShape;
+  if (!axes.empty()) {
+    groupShape = llvm::to_vector_of<Value>(
+        builder.create<mesh::MeshShapeOp>(mesh, axes).getResults());
+  }
+  return llvm::cast<TypedValue<IndexType>>(arith::createProduct(
+      builder, builder.getLoc(), groupShape, builder.getIndexType()));
+}
+
} // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
new file mode 100644
index 00000000000000..3e3f584caca247
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
@@ -0,0 +1,43 @@
+//===- TransformsDetail.h - -------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+/// Base for rewrite patterns over `Op` that look up symbols (e.g. the
+/// `mesh.mesh` referenced by an op) through a `SymbolTableCollection` supplied
+/// by whoever populates the patterns. The collection is held by reference, so
+/// it must outlive the pattern.
+template <typename Op>
+struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
+  /// Remaining arguments (e.g. the `MLIRContext *`) are perfect-forwarded to
+  /// the `OpRewritePattern<Op>` constructor.
+  template <typename... OpRewritePatternArgs>
+  OpRewritePatternWithSymbolTableCollection(
+      SymbolTableCollection &symbolTableCollection,
+      OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern<Op>(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
+
+protected:
+  SymbolTableCollection &symbolTableCollection;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2e42c4e870716b..18ca3c332e0204 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -346,6 +346,28 @@ TypedAttr Builder::getZeroAttr(Type type) {
 return {};
 }
+/// Returns the attribute representing the value one for `type`: 1.0 for
+/// float types, 1 for index and integer types, and a splat of that scalar for
+/// ranked tensors and vectors whose element type is one of those. Returns a
+/// null attribute for any other type.
+TypedAttr Builder::getOneAttr(Type type) {
+  if (llvm::isa<FloatType>(type))
+    return getFloatAttr(type, 1.0);
+  if (llvm::isa<IndexType>(type))
+    return getIndexAttr(1);
+  if (auto integerType = llvm::dyn_cast<IntegerType>(type))
+    return getIntegerAttr(type, APInt(integerType.getWidth(), 1));
+  if (llvm::isa<RankedTensorType, VectorType>(type)) {
+    auto shapedType = llvm::cast<ShapedType>(type);
+    // Build the scalar first; an unsupported element type yields no splat.
+    auto element = getOneAttr(shapedType.getElementType());
+    if (!element)
+      return {};
+    return DenseElementsAttr::get(shapedType, element);
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions, Affine Maps, and Integer Sets.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
new file mode 100644
index 00000000000000..8aa3fc5160a001
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt --split-input-file --test-mesh-all-scatter-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+
+mesh.mesh @mesh_1d(shape = ?)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?xf16>
+  %arg0: tensor<?xf16>
+// CHECK-SAME: -> tensor<?xf16> {
+) -> tensor<?xf16> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?xf16>
+  return %0 : tensor<?xf16>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh
+func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+  // CHECK: %[[INPUT:.*]]: tensor<2xf16>
+  %input: tensor<2xf16>
+// CHECK-SAME: -> tensor<1xf16> {
+) -> tensor<1xf16> {
+  // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 : index
+  // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK: %[[EXTRACTED:.*]] = tensor.extract_slice %[[INPUT]][%[[PROCESS_INDEX]]] [%[[ONE]]] [1] : tensor<2xf16> to tensor<?xf16>
+  // CHECK: %[[RESULT:.*]] = tensor.cast %[[EXTRACTED]] : tensor<?xf16> to tensor<1xf16>
+  %scattered = mesh.all_scatter %input on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+  // CHECK: return %[[RESULT]] : tensor<1xf16>
+  return %scattered : tensor<1xf16>
+}
+
+// -----
+
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+
+mesh.mesh @mesh_4d(shape = ?x?x?x?)
+
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
+  %arg0 : tensor<?x?xf16>
+// CHECK-SAME: -> tensor<?x?xf16> {
+) -> tensor<?x?xf16> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
+  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+  // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
+  // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
+  // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
+  // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_4d mesh_axes = [3, 1] scatter_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?x?xf16>
+  return %0 : tensor<?x?xf16>
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 23c5b253b4c073..4fa9c81f4cd004 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,19 @@ func.func @all_gather_empty_mesh_axes(
return %0 : tensor<4xf32>
}
+// CHECK-LABEL: func @all_scatter_empty_mesh_axes
+func.func @all_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-NOT: mesh.all_scatter
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
// CHECK-LABEL: func @broadcast_empty_mesh_axes
func.func @broadcast_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3fa3ebd67b15e7..590d548d9c769e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -316,6 +316,58 @@ func.func @all_gather_invalid_negative_gather_axis(
// -----
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<3xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
+  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
func.func @all_to_all_invalid_mesh_symbol(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 40a8469b264643..69e8545f3cb86f 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -208,6 +208,30 @@ func.func @all_gather_dynamic_dims_in_mesh(
return %0 : tensor<5x?xf32>
}
+// CHECK-LABEL: func @all_scatter_static_dimensions
+func.func @all_scatter_static_dimensions(
+    // CHECK-SAME: %[[INPUT:.*]]: tensor<3x4xf32>
+    %input : tensor<3x4xf32>) -> tensor<3x1xf32> {
+  // CHECK-NEXT: mesh.all_scatter %[[INPUT]]
+  // CHECK-SAME: on @mesh0 mesh_axes = [2] scatter_axis = 1
+  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
+  %scattered = mesh.all_scatter %input on @mesh0 mesh_axes = [2] scatter_axis = 1
+    : tensor<3x4xf32> -> tensor<3x1xf32>
+  return %scattered : tensor<3x1xf32>
+}
+
+// CHECK-LABEL: func @all_scatter_dynamic_dimensions
+func.func @all_scatter_dynamic_dimensions(
+    // CHECK-SAME: %[[INPUT:.*]]: tensor<?xf32>
+    %input : tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK-NEXT: mesh.all_scatter %[[INPUT]]
+  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
+  %scattered = mesh.all_scatter %input on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %scattered : tensor<?xf32>
+}
+
// CHECK-LABEL: func @all_to_all
func.func @all_to_all(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index f1d686135c28eb..fdf9905a62e99c 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -19,17 +19,9 @@ func.func @split_replicated_tensor_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
%arg0: tensor<3x14xf32>
) -> tensor<3x14xf32> {
- // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
- // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
- // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
- // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
- // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
- // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+ // CHECK: %[[ALL_SCATTER:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 1
+ // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SCATTER]] : tensor<3x7xf32> to tensor<3x14xf32>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
%1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
// CHECK: return %[[RESULT]] : tensor<3x14xf32>
@@ -41,22 +33,11 @@ func.func @split_replicated_tensor_axis_dynamic(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
%arg0: tensor<?x3x?xf32>
) -> tensor<?x3x?xf32> {
- // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
- // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
- // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index
- // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
- // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
- // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
- // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
- // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
- // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+ // CHECK: %[[RESULT:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] scatter_axis = 0
+ // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
%0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
%1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
- // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+ // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
return %1 : tensor<?x3x?xf32>
}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 9993c1518e9eae..3aeea169107d66 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -63,9 +63,8 @@ func.func @unary_elementwise_with_resharding(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
- // We don't care about the whole resharding IR, just that it happens.
- // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
- // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ // CHECK: %[[SLICE:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -109,9 +108,8 @@ func.func @multiple_chained_ops(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- // We don't care about the whole resharding IR, just that it happens.
- // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
- // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ // CHECK: %[[RESHARD1:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -122,8 +120,8 @@ func.func @multiple_chained_ops(
%4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
%5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1]
- // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ // CHECK: %[[RESHARD3:.*]] = mesh.all_scatter %[[ABS2]] on @mesh_1d mesh_axes = [0] scatter_axis = 0 :
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
%7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 00931e6c94fc5f..07e9bb6f9f238a 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,6 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTest
- TestProcessMultiIndexOpLowering.cpp
+ TestOpLowering.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
new file mode 100644
index 00000000000000..2c9131e666ed69
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
@@ -0,0 +1,95 @@
+//===- TestOpLowering.cpp - Mesh op lowering test passes ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test passes that apply the mesh all-scatter and process_multi_index lowering
+// patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+/// Fills a fresh pattern set through `populatePatterns` and applies it
+/// greedily to `op`. The SymbolTableCollection is kept alive for the whole
+/// rewrite because the populated patterns may hold a reference to it.
+static void applyLoweringPatterns(
+    Operation *op,
+    function_ref<void(RewritePatternSet &, SymbolTableCollection &)>
+        populatePatterns) {
+  RewritePatternSet patterns(op->getContext());
+  SymbolTableCollection symbolTableCollection;
+  populatePatterns(patterns, symbolTableCollection);
+  LogicalResult status = applyPatternsAndFoldGreedily(op, std::move(patterns));
+  (void)status;
+  assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
+}
+
+namespace {
+
+/// Test pass applying the mesh.all_scatter lowering patterns.
+struct TestAllScatterOpLoweringPass
+    : public PassWrapper<TestAllScatterOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllScatterOpLoweringPass)
+
+  void runOnOperation() override {
+    applyLoweringPatterns(getOperation(), [](RewritePatternSet &patterns,
+                                             SymbolTableCollection &symbols) {
+      mesh::allScatterOpLoweringPopulatePatterns(patterns, symbols);
+    });
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::allScatterOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-all-scatter-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of all-scatter.";
+  }
+};
+
+/// Test pass applying the mesh.process_multi_index lowering patterns.
+struct TestMultiIndexOpLoweringPass
+    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
+
+  void runOnOperation() override {
+    applyLoweringPatterns(getOperation(), [](RewritePatternSet &patterns,
+                                             SymbolTableCollection &symbols) {
+      mesh::processMultiIndexOpLoweringPopulatePatterns(patterns, symbols);
+    });
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-process-multi-index-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of mesh.process_multi_index op.";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the mesh op lowering test passes defined in this file.
+void registerTestOpLoweringPasses() {
+  PassRegistration<TestAllScatterOpLoweringPass>();
+  PassRegistration<TestMultiIndexOpLoweringPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
deleted file mode 100644
index 0bcc403a2734ed..00000000000000
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ /dev/null
@@ -1,54 +0,0 @@
-//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestMultiIndexOpLoweringPass
- : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
-
- void runOnOperation() override;
- void getDependentDialects(DialectRegistry &registry) const override {
- mesh::processMultiIndexOpLoweringRegisterDialects(registry);
- }
- StringRef getArgument() const final {
- return "test-mesh-process-multi-index-op-lowering";
- }
- StringRef getDescription() const final {
- return "Test lowering of mesh.process_multi_index op.";
- }
-};
-} // namespace
-
-void TestMultiIndexOpLoweringPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- SymbolTableCollection symbolTableCollection;
- mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
- symbolTableCollection);
- LogicalResult status =
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- (void)status;
- assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
-}
-
-namespace mlir {
-namespace test {
-void registerTestMultiIndexOpLoweringPass() {
- PassRegistration<TestMultiIndexOpLoweringPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..a94f42927f3f3b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,7 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
-void registerTestMultiIndexOpLoweringPass();
+void registerTestOpLoweringPasses();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -241,7 +241,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
- mlir::test::registerTestMultiIndexOpLoweringPass();
+ mlir::test::registerTestOpLoweringPasses();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9f2c1f1120066b..8943e91c1f0ceb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3314,7 +3314,9 @@ cc_library(
includes = ["include"],
deps = [
":AffineDialect",
+ ":AffineUtils",
":ArithDialect",
+ ":ArithUtils",
":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
>From 8cbbb177dc5ba8803e4b068672dc679ef230fe80 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 8 Feb 2024 19:10:01 -0800
Subject: [PATCH 2/2] Remove unused resharding dependencies
---
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 21e03d9572c590..cd0481d7b87aec 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -8,9 +8,6 @@
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
@@ -519,8 +516,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
}
void reshardingRegisterDependentDialects(DialectRegistry ®istry) {
- registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
- cf::ControlFlowDialect>();
+ registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}
#define GEN_PASS_DEF_SPMDIZATION
More information about the Mlir-commits
mailing list