[Mlir-commits] [llvm] [mlir] [mlir][mesh] Add all-scatter operation (PR #81218)
Boian Petkantchin
llvmlistbot at llvm.org
Tue Feb 13 14:23:25 PST 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/81218
>From 5d1e4cdaf9b5722d8322adaf36b4dce1962381e3 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 7 Feb 2024 08:05:53 -0800
Subject: [PATCH 1/3] [mlir][mesh] Add all-scatter operation
This op is the inverse of all-gather. It is useful to have an explicit
concise representation instead of having a blob of slicing logic.
Add lowering for the op that slices from the tensor based on the in-group
process index.
Make resharding generate an all-scatter instead of inserting the slicing logic
directly.
---
mlir/include/mlir/Dialect/Affine/Utils.h | 7 +
mlir/include/mlir/Dialect/Arith/Utils/Utils.h | 17 +++
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 6 +
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 63 ++++++++
.../mlir/Dialect/Mesh/Transforms/Transforms.h | 19 ++-
mlir/include/mlir/IR/Builders.h | 3 +
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 27 ++++
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 19 +++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 64 ++++++++-
.../Dialect/Mesh/Transforms/CMakeLists.txt | 2 +
.../Mesh/Transforms/Simplifications.cpp | 16 +--
.../Dialect/Mesh/Transforms/Spmdization.cpp | 84 ++---------
.../Dialect/Mesh/Transforms/Transforms.cpp | 136 ++++++++++++++++--
.../Mesh/Transforms/TransformsDetail.h | 35 +++++
mlir/lib/IR/Builders.cpp | 18 +++
.../Dialect/Mesh/all-scatter-op-lowering.mlir | 74 ++++++++++
mlir/test/Dialect/Mesh/canonicalization.mlir | 13 ++
mlir/test/Dialect/Mesh/invalid.mlir | 52 +++++++
mlir/test/Dialect/Mesh/ops.mlir | 24 ++++
.../Dialect/Mesh/resharding-spmdization.mlir | 31 +---
mlir/test/Dialect/Mesh/spmdization.mlir | 14 +-
mlir/test/lib/Dialect/Mesh/CMakeLists.txt | 2 +-
mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp | 79 ++++++++++
.../Mesh/TestProcessMultiIndexOpLowering.cpp | 54 -------
mlir/tools/mlir-opt/mlir-opt.cpp | 4 +-
.../llvm-project-overlay/mlir/BUILD.bazel | 2 +
26 files changed, 673 insertions(+), 192 deletions(-)
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
create mode 100644 mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
create mode 100644 mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
delete mode 100644 mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index b3ccbff3002fb1..441ce1952210bb 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -15,12 +15,14 @@
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/OpDefinition.h"
#include <optional>
namespace mlir {
class DominanceInfo;
class Operation;
class PostDominanceInfo;
+class ImplicitLocOpBuilder;
namespace func {
class FuncOp;
@@ -309,6 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
+/// Generate IR that computes the linear index of `multiIndex` within a
+/// row-major index space whose extents are given by `shape`.
+/// `multiIndex` and `shape` must have the same number of elements.
+OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                  ArrayRef<OpFoldResult> shape,
+                                  ImplicitLocOpBuilder &builder);
/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
index 402bd196f0736a..2111a7c5810294 100644
--- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/ArrayRef.h"
namespace mlir {
@@ -81,6 +82,22 @@ struct ArithBuilder {
OpBuilder &b;
Location loc;
};
+
+namespace arith {
+
+/// Build the product of a sequence of values.
+/// If `values` = (v0, v1, ..., vn), then the returned value is
+/// v0 * v1 * ... * vn. All values must have the same type.
+///
+/// The overload without `resultType` requires `values` to be non-empty; the
+/// result has the same type as the elements of `values`.
+/// The overload with `resultType` multiplies the values into an accumulator of
+/// type `resultType` (so the values are expected to have that type) and
+/// returns a constant 1 of type `resultType` when `values` is empty.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType);
+} // namespace arith
} // namespace mlir
#endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9154e6fd803102..fb9425b96e68e2 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -92,6 +92,12 @@ int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
return res;
}
+/// Convenience overload of `collectiveProcessGroupSize` that takes the mesh
+/// shape from the `mesh` op itself.
+template <typename MeshAxesRange>
+int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
+  auto meshShape = mesh.getShape();
+  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
+                                    meshShape);
+}
+
// Get the size of a sharded dimension.
inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index da372706ec724c..7792aac784d4be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -96,6 +96,7 @@ def Mesh_MeshShapeOp : Mesh_Op<"mesh_shape", [
let builders = [
OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh)>,
+ OpBuilder<(ins "::mlir::mesh::MeshOp":$mesh, "ArrayRef<MeshAxis>":$axes)>,
OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -341,6 +342,68 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}
+def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+    Pure,
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultRank
+  ]> {
+  let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+  let description = [{
+    Scatter along the `scatter_axis` tensor axis.
+    This operation can be thought of as the inverse of all-gather.
+    Technically, it is not required that all processes have the same input tensor.
+    Each process will slice a piece of its local tensor based on its in-group device index.
+    The operation does not communicate data between devices.
+    The size of the `scatter_axis` dimension is expected to be divisible by the
+    number of processes in the group formed by `mesh_axes`.
+
+    Example:
+    ```mlir
+    mesh.mesh @mesh0(shape = 2x2)
+    ...
+    %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+      : tensor<2x4xi8> -> tensor<2x2xi8>
+    ```
+    Input:
+    ```
+    +-------------+
+    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
+    |  3  4  7  8 |
+    +-------------+
+    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
+    | 11 12 15 16 |
+    +-------------+
+    ```
+    Result:
+    ```
+                     scatter tensor
+                     axis 1
+                     ------------>
+                     +-------+-------+
+    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
+                     |  3  4 |  7  8 |
+                     +-------+-------+
+    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
+                     | 11 12 | 15 16 |
+                     +-------+-------+
+    ```
+  }];
+  let arguments = !con(commonArgs, (ins
+    AnyNon0RankedTensor:$input,
+    IndexAttr:$scatter_axis
+  ));
+  let results = (outs
+    AnyNon0RankedTensor:$result
+  );
+  let assemblyFormat = [{
+    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+    attr-dict `:` type($input) `->` type($result)
+  }];
+  let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "MeshOp":$mesh,
+      "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
+    OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh,
+      "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+  ];
+}
+
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
Pure,
SameOperandsAndResultElementType,
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index 10a965daac71b9..d398bdd65330b1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -9,17 +9,34 @@
#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+
namespace mlir {
class RewritePatternSet;
class SymbolTableCollection;
class DialectRegistry;
+class ImplicitLocOpBuilder;
namespace mesh {
void processMultiIndexOpLoweringPopulatePatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-
void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+
+/// Add the pattern that lowers `mesh.all_scatter` into slicing of the local
+/// tensor. Meshes are resolved through `symbolTableCollection`, which must
+/// outlive the patterns.
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+/// Register the dialects whose ops the `mesh.all_scatter` lowering creates.
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+
+/// Add every mesh op lowering pattern declared above
+/// (`mesh.process_multi_index` and `mesh.all_scatter`).
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
+/// Register the dialects required by `populateAllOpLoweringPatterns`.
+void registerAllOpLoweringDialects(DialectRegistry &registry);
+
+/// Create IR that computes, as an `index` value, the size of the process
+/// group spanned by `axes` of `mesh`.
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2fe1495b2b593b..43b6d2b3841690 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -118,6 +118,9 @@ class Builder {
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
// ranked tensor of them. Returns null attribute otherwise.
TypedAttr getZeroAttr(Type type);
+  // Returns a 1-valued attribute of the given `type`.
+  // Supports float, index and integer types, and ranked tensors or vectors of
+  // them (a splat of 1), mirroring `getZeroAttr`. Returns a null attribute
+  // for any other type.
+ TypedAttr getOneAttr(Type type);
// Convenience methods for fixed types.
FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 4d4adb94a9fc8d..41bcd2b9f33e6b 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -20,9 +20,11 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
@@ -1869,3 +1871,28 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
results.push_back(residual);
return results;
}
+
+/// Materialize the row-major linear index of `multiIndex` in an index space
+/// with extents `shape`, folding to constants wherever possible.
+OpFoldResult
+mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
+                                   ArrayRef<OpFoldResult> shape,
+                                   ImplicitLocOpBuilder &builder) {
+  assert(multiIndex.size() == shape.size());
+  MLIRContext *ctx = builder.getContext();
+  Location loc = builder.getLoc();
+
+  // Model each extent as an affine symbol s_i so that the strides can be
+  // derived symbolically and then folded against the concrete `shape`.
+  SmallVector<AffineExpr> extentSymbols;
+  extentSymbols.reserve(shape.size());
+  for (unsigned pos = 0, end = shape.size(); pos < end; ++pos)
+    extentSymbols.push_back(getAffineSymbolExpr(pos, ctx));
+
+  SmallVector<OpFoldResult> strides;
+  for (AffineExpr strideExpr : computeStrides(extentSymbols))
+    strides.push_back(
+        affine::makeComposedFoldedAffineApply(builder, loc, strideExpr, shape));
+
+  auto [linearExpr, linearOperands] = computeLinearIndex(
+      OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
+  return affine::makeComposedFoldedAffineApply(builder, loc, linearExpr,
+                                               linearOperands);
+}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bf274d4ae27ed8..999cdbc5c10008 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "llvm/ADT/SmallBitVector.h"
+#include <numeric>
using namespace mlir;
@@ -262,3 +263,21 @@ Value ArithBuilder::slt(Value lhs, Value rhs) {
Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) {
return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
}
+
+namespace mlir::arith {
+
+/// Product of a non-empty list of values; the result type is the type of the
+/// first element.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values) {
+  // `front()` on an empty list is undefined behavior, so reject it loudly.
+  assert(!values.empty() &&
+         "createProduct without a result type needs at least one value");
+  return createProduct(builder, loc, values, values.front().getType());
+}
+
+/// Product of `values`, seeded with a constant 1 of `resultType` so that an
+/// empty list yields that constant.
+Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
+                    Type resultType) {
+  TypedAttr oneAttr = builder.getOneAttr(resultType);
+  // `getOneAttr` returns a null attribute for unsupported types.
+  assert(oneAttr && "createProduct: unsupported result type");
+  Value one = builder.create<ConstantOp>(loc, resultType, oneAttr);
+  ArithBuilder arithBuilder(builder, loc);
+  return std::accumulate(
+      values.begin(), values.end(), one,
+      [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
+}
+
+} // namespace mlir::arith
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index a65b8f2e5a2376..762725d2c56e66 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -252,14 +252,20 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+// Builds a `mesh.mesh_shape` op for `mesh` with an empty axes list, i.e. one
+// index result per mesh dimension (see the overload below).
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
MeshOp mesh) {
+ build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
+}
+
+// Builds a `mesh.mesh_shape` op for `axes` of `mesh`. One index result is
+// created per entry of `axes`; an empty `axes` list yields one index result
+// per mesh dimension instead.
+void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ MeshOp mesh, ArrayRef<MeshAxis> axes) {
build(odsBuilder, odsState,
- SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
- mesh.getSymName(),
- MeshAxesAttr::get(odsBuilder.getContext(), SmallVector<MeshAxis>()));
+ SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
+ odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
}
void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
StringRef mesh, ArrayRef<MeshAxis> axes) {
+ assert(!axes.empty());
build(odsBuilder, odsState,
SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
MeshAxesAttr::get(odsBuilder.getContext(), axes));
@@ -590,6 +596,22 @@ static LogicalResult verifyScatterOperandAndResultShape(
return success();
}
+/// Compute the result type of scattering a tensor of type `operandType` along
+/// `scatterAxis` over the process group formed by `meshAxes` of `mesh`: the
+/// scattered dimension is divided by the group size, all other dimensions are
+/// kept. The division uses `DimensionSize` arithmetic, which is expected to
+/// keep the dimension dynamic when either side is dynamic (see its definition).
+static RankedTensorType scatterResultType(Type operandType, MeshOp mesh,
+                                          ArrayRef<MeshAxis> meshAxes,
+                                          int64_t scatterAxis) {
+  // Free-function cast; the `Type::cast` member form is deprecated.
+  RankedTensorType operandRankedTensorType =
+      llvm::cast<RankedTensorType>(operandType);
+  DimensionSize operandScatterAxisSize =
+      operandRankedTensorType.getShape()[scatterAxis];
+  SmallVector<int64_t> resultShape =
+      llvm::to_vector(operandRankedTensorType.getShape());
+
+  resultShape[scatterAxis] =
+      operandScatterAxisSize /
+      DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
+  return operandRankedTensorType.clone(resultShape);
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_gather op
//===----------------------------------------------------------------------===//
@@ -625,6 +647,42 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+//===----------------------------------------------------------------------===//
+// mesh.all_scatter op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Nothing more to check if the mesh lookup / axes verification failed.
+  auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
+  if (failed(mesh))
+    return failure();
+
+  int64_t scatterAxis = getScatterAxis().getSExtValue();
+  return verifyScatterOperandAndResultShape(getOperand(), getResult(),
+                                            scatterAxis, getMeshAxes(),
+                                            mesh->getShape());
+}
+
+void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  // All-scatter ops with an empty `mesh_axes` list are handled by the shared
+  // empty-mesh-axes pattern.
+  patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+                         int64_t scatterAxis) {
+  // Infer the result type from the operand type and the process group size.
+  Type inferredType =
+      scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+  build(odsBuilder, odsState, inferredType, input, mesh.getSymName(), meshAxes,
+        scatterAxis);
+}
+
+void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                         Type resultType, Value input, StringRef mesh,
+                         ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+  // `scatter_axis` is an IndexAttr, whose value is carried as an APInt.
+  APInt scatterAxisValue(sizeof(scatterAxis) * CHAR_BIT, scatterAxis);
+  build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
+        scatterAxisValue);
+}
+
//===----------------------------------------------------------------------===//
// mesh.all_to_all op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index dccb75848c94f0..28af820440076c 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -13,7 +13,9 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRAffineDialect
+ MLIRAffineUtils
MLIRArithDialect
+ MLIRArithUtils
MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index c0273cdaef7144..7fcac2312444f3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "TransformsDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -16,7 +17,6 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
-#include <iterator>
#include <numeric>
#include <utility>
@@ -56,13 +56,10 @@ namespace {
// symbol tables.
// We can't use DialectFoldInterface since the cache may be invalidated by some
// pass changing the referenced MeshOp ops.
-struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
- template <typename... OpRewritePatternArgs>
- MeshShapeFolder(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
- : OpRewritePattern(
- std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
- symbolTableCollection(symbolTableCollection) {}
+struct MeshShapeFolder
+ : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
LogicalResult matchAndRewrite(MeshShapeOp op,
PatternRewriter &rewriter) const override {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -113,9 +110,6 @@ struct MeshShapeFolder : OpRewritePattern<MeshShapeOp> {
return success();
}
-
-private:
- SymbolTableCollection &symbolTableCollection;
};
} // namespace
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index b649157a9e46de..21e03d9572c590 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -128,92 +128,24 @@ targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
}
-static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
- int64_t splitTensorAxis,
- int64_t splitCount) {
- SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
- targetShape[splitTensorAxis] =
- shardDimension(targetShape[splitTensorAxis], splitCount);
- return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
-}
-
// Split a replicated tensor along a mesh axis.
// e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
-//
-// The implementation is the extract the tensor slice corresponding
-// to the current device.
+//
+// The split is expressed as a `mesh.all_scatter` over `splitMeshAxis`; the
+// per-device slice is only materialized when that op is lowered.
static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- MLIRContext *ctx = builder.getContext();
- builder.setInsertionPointAfterValue(sourceShard);
-
- Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
-
- Value processIndexAlongAxis =
+  TypedValue<ShapedType> targetShard = llvm::cast<TypedValue<ShapedType>>(
builder
- .create<ProcessMultiIndexOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
- .getResult()[0];
-
+          .create<AllScatterOp>(sourceShard, mesh,
+                                ArrayRef<MeshAxis>(splitMeshAxis),
+                                splitTensorAxis)
+          .getResult());
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
- ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
- ShapedType targetShape = targetShapeInSplitLastAxis(
- sourceShard.getType(), splitTensorAxis, mesh.getShape()[splitMeshAxis]);
-
- Value meshAxisSize =
- builder
- .create<MeshShapeOp>(mesh.getSymName(),
- SmallVector<MeshAxis>({splitMeshAxis}))
- .getResult()[0];
-
- Value sourceAxisSize =
- builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
- Value sourceAxisSizeModMeshAxisSize =
- builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
- Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
- builder.create<cf::AssertOp>(
- isTargetShapeExactlyDivisible,
- "Sharding a tensor with axis size that is not exactly divisible by the "
- "mesh axis size is not supported.");
- Value targetAxisSize =
- builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
- Value axisOffset =
- builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
- SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
- staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
- DenseI64ArrayAttr staticOffsetsAttr =
- DenseI64ArrayAttr::get(ctx, staticOffsets);
- SmallVector<Value> dynamicOffsets(1, axisOffset);
-
- DenseI64ArrayAttr staticSizesAttr =
- DenseI64ArrayAttr::get(ctx, targetShape.getShape());
- SmallVector<Value> dynamicSizes;
- for (int64_t i = 0; i < targetShape.getRank(); ++i) {
- if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
- if (i == splitTensorAxis) {
- dynamicSizes.push_back(targetAxisSize);
- } else {
- Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
- dynamicSizes.push_back(dimSize);
- }
- }
- }
-
- DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
- ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
- TypedValue<RankedTensorType> targetShard =
- builder
- .create<tensor::ExtractSliceOp>(
- targetShape, sourceShard, dynamicOffsets, dynamicSizes,
- SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
- staticStridesAttr)
- .getResult();
- return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
+  return {targetShard, targetSharding};
}
// Detect if the resharding is of type e.g.
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 03b1d9b3498028..95739c18ce0f36 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -7,12 +7,21 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "TransformsDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
@@ -26,18 +35,14 @@ namespace {
/// Lower `mesh.process_multi_index` into expression using
/// `mesh.process_linear_index` and `mesh.mesh_shape`.
-struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
- template <typename... OpRewritePatternArgs>
- ProcessMultiIndexOpLowering(SymbolTableCollection &symbolTableCollection,
- OpRewritePatternArgs &&...opRewritePatternArgs)
- : OpRewritePattern(
- std::forward<OpRewritePatternArgs...>(opRewritePatternArgs)...),
- symbolTableCollection(symbolTableCollection) {}
+struct ProcessMultiIndexOpLowering
+ : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
+ using OpRewritePatternWithSymbolTableCollection::
+ OpRewritePatternWithSymbolTableCollection;
LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
PatternRewriter &rewriter) const override {
- MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
- op.getOperation(), op.getMeshAttr());
+ MeshOp mesh = getMesh(op, symbolTableCollection);
if (!mesh) {
return failure();
}
@@ -46,6 +51,7 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
builder.setInsertionPointAfter(op.getOperation());
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
+    // Built from just `mesh`, `MeshShapeOp` yields one extent per mesh axis.
SmallVector<Value> completeMultiIndex =
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
.getMultiIndex();
@@ -64,9 +70,83 @@ struct ProcessMultiIndexOpLowering : OpRewritePattern<ProcessMultiIndexOp> {
rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
return success();
}
+};
+
+/// Lower `mesh.all_scatter` into a `tensor.extract_slice` of the local operand.
+/// Along the scatter axis, the slice starts at the process's linear index
+/// within the group spanned by the op's mesh axes, times the slice size. A
+/// runtime `cf.assert` checks that the scatter axis size is divisible by the
+/// process group size.
+struct AllScatterOpLowering
+    : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+  using OpRewritePatternWithSymbolTableCollection::
+      OpRewritePatternWithSymbolTableCollection;
+
+  LogicalResult matchAndRewrite(AllScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    MeshOp mesh = getMesh(op, symbolTableCollection);
+    if (!mesh) {
+      return failure();
+    }
+
+    // With no mesh axes the process group is just the current process, so the
+    // whole operand is kept. Handle this up front: the `MeshShapeOp` builder
+    // taking a symbol name asserts on an empty axes list.
+    if (op.getMeshAxes().empty()) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getResult().getType(),
+                                                  op.getOperand());
+      return success();
+    }
+
+    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+    builder.setInsertionPointAfter(op.getOperation());
+
+    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+    Operation::result_range processInGroupMultiIndex =
+        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResults();
+
+    Operation::result_range processGroupShape =
+        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
+            .getResult();
+    Value processGroupSize =
+        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
+
+    int64_t scatterAxis = op.getScatterAxis().getSExtValue();
+    Value operandScatterAxisSize =
+        builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
+    Value operandScatterAxisSizeModProcessGroupSize =
+        builder.create<arith::RemUIOp>(operandScatterAxisSize,
+                                       processGroupSize);
+    Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+        arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+        zero);
+    builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
+                                 "Scattering a tensor with axis size that is "
+                                 "not exactly divisible by the "
+                                 "mesh process group size is not supported.");
+    Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
+        operandScatterAxisSize, processGroupSize);
+    OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+        llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+        llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+
+    // Extract this process's slice: the full extent along every axis except
+    // the scatter axis, where the slice starts at
+    // processInGroupLinearIndex * resultScatterAxisSize.
+    RankedTensorType operandType =
+        llvm::cast<RankedTensorType>(op.getOperand().getType());
+    SmallVector<OpFoldResult> sizes;
+    for (int64_t i = 0; i < operandType.getRank(); ++i) {
+      if (i == scatterAxis) {
+        sizes.emplace_back(resultScatterAxisSize);
+      } else {
+        Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
+        sizes.emplace_back(dimSize);
+      }
+    }
+    SmallVector<OpFoldResult> offsets(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
+    offsets[scatterAxis] =
+        ArithBuilder(builder, builder.getLoc())
+            .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
+                                                 processInGroupLinearIndex),
+                 resultScatterAxisSize);
+    SmallVector<OpFoldResult> strides(
+        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
+    Value slice = builder.create<tensor::ExtractSliceOp>(
+        op.getOperand(), offsets, sizes, strides);
+    // The sizes are SSA values, so the slice type may be less static than the
+    // op's declared result type; cast back to it.
+    Value newResult =
+        builder.create<tensor::CastOp>(op.getResult().getType(), slice);
+    // Replace and erase the op so it cannot be matched again.
+    rewriter.replaceOp(op, newResult);
-private:
- SymbolTableCollection &symbolTableCollection;
+    return success();
+  }
};
} // namespace
@@ -81,4 +161,38 @@ void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, mesh::MeshDialect>();
}
+/// Add the `mesh.all_scatter` lowering pattern. Meshes are resolved through
+/// `symbolTableCollection`, which the pattern keeps a reference to.
+void allScatterOpLoweringPopulatePatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  patterns.add<AllScatterOpLowering>(symbolTableCollection,
+                                     patterns.getContext());
+}
+
+/// Register every dialect whose ops the `mesh.all_scatter` lowering may create
+/// (affine index arithmetic, arith, the cf.assert divisibility check, mesh
+/// queries and tensor slicing).
+void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
+  registry.insert<affine::AffineDialect, arith::ArithDialect,
+                  cf::ControlFlowDialect, mesh::MeshDialect,
+                  tensor::TensorDialect>();
+}
+
+/// Add the lowering patterns for both `mesh.process_multi_index` and
+/// `mesh.all_scatter`.
+void populateAllOpLoweringPatterns(
+    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
+  processMultiIndexOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+  allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+}
+
+/// Register the dialects required by `populateAllOpLoweringPatterns`.
+void registerAllOpLoweringDialects(DialectRegistry &registry) {
+  processMultiIndexOpLoweringRegisterDialects(registry);
+  allScatterOpLoweringRegisterDialects(registry);
+}
+
+/// Create IR that computes the size of the process group spanned by `axes` of
+/// `mesh`, i.e. the product of the extents of those mesh axes.
+TypedValue<IndexType>
+createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
+                                 ImplicitLocOpBuilder &builder) {
+  // An empty axis list denotes a group containing only the current process,
+  // matching `collectiveProcessGroupSize`. Do not query `mesh.mesh_shape` in
+  // that case: built with no axes, it returns the extents of all mesh axes.
+  SmallVector<Value> groupShape;
+  if (!axes.empty())
+    groupShape = llvm::to_vector_of<Value>(
+        builder.create<mesh::MeshShapeOp>(mesh, axes).getResults());
+  // `createProduct` with an explicit result type yields the constant 1 for an
+  // empty list.
+  return llvm::cast<TypedValue<IndexType>>(arith::createProduct(
+      builder, builder.getLoc(), groupShape, builder.getIndexType()));
+}
+
} // namespace mlir::mesh
diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
new file mode 100644
index 00000000000000..3e3f584caca247
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h
@@ -0,0 +1,35 @@
+//===- TransformsDetail.h - -------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace mesh {
+
+/// Base class for rewrite patterns on `Op` that resolve symbols through a
+/// shared `SymbolTableCollection`. The collection is held by reference and
+/// must outlive the pattern.
+template <typename Op>
+struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> {
+  /// All arguments after `symbolTableCollection` (e.g. context, benefit) are
+  /// perfectly forwarded to `OpRewritePattern<Op>`. Each argument is forwarded
+  /// with its own deduced type, so any number of them is supported.
+  template <typename... OpRewritePatternArgs>
+  OpRewritePatternWithSymbolTableCollection(
+      SymbolTableCollection &symbolTableCollection,
+      OpRewritePatternArgs &&...opRewritePatternArgs)
+      : OpRewritePattern<Op>(
+            std::forward<OpRewritePatternArgs>(opRewritePatternArgs)...),
+        symbolTableCollection(symbolTableCollection) {}
+
+protected:
+  SymbolTableCollection &symbolTableCollection;
+};
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2e42c4e870716b..18ca3c332e0204 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -346,6 +346,24 @@ TypedAttr Builder::getZeroAttr(Type type) {
return {};
}
+/// Returns a typed attribute holding the value one for `type`, or a null
+/// attribute when `type` is unsupported. Float, index and integer scalars are
+/// supported directly. Ranked tensors and vectors are supported when their
+/// element type is, in which case a DenseElementsAttr built from the scalar
+/// one is returned.
+TypedAttr Builder::getOneAttr(Type type) {
+  if (llvm::isa<FloatType>(type))
+    return getFloatAttr(type, 1.0);
+  if (llvm::isa<IndexType>(type))
+    return getIndexAttr(1);
+  // Bind the cast result once instead of dyn_cast-ing and then cast-ing again.
+  if (auto integerType = llvm::dyn_cast<IntegerType>(type))
+    return getIntegerAttr(type, APInt(integerType.getWidth(), 1));
+  if (llvm::isa<RankedTensorType, VectorType>(type)) {
+    auto shapedType = llvm::cast<ShapedType>(type);
+    // Propagate "unsupported" from the element type instead of building a
+    // dense attribute from a null element.
+    TypedAttr element = getOneAttr(shapedType.getElementType());
+    if (!element)
+      return {};
+    return DenseElementsAttr::get(shapedType, element);
+  }
+  return {};
+}
+
//===----------------------------------------------------------------------===//
// Affine Expressions, Affine Maps, and Integer Sets.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
new file mode 100644
index 00000000000000..8aa3fc5160a001
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt --split-input-file --test-mesh-all-scatter-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+
+mesh.mesh @mesh_1d(shape = ?)
+
+// Dynamic tensor on a dynamic 1-d mesh: the lowering emits a runtime
+// divisibility assertion, then slices [index * size, size) along axis 0.
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?xf16>
+  %arg0: tensor<?xf16>
+// CHECK-SAME: -> tensor<?xf16> {
+) -> tensor<?xf16> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+  // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+  // CHECK: %[[TENSOR_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
+  // CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?xf16>
+  return %0 : tensor<?xf16>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// Static tensor<2xf16> on a static 2-process 1-d mesh: each process takes a
+// slice of size 1 at offset equal to its process index. The expected IR
+// slices with a dynamic size and then casts to the static tensor<1xf16>.
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh
+func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+ // CHECK: %[[ARG:.*]]: tensor<2xf16>
+ %arg0: tensor<2xf16>
+// CHECK-SAME: -> tensor<1xf16> {
+) -> tensor<1xf16> {
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
+ // CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
+ %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+ // CHECK: return %[[RESULT]] : tensor<1xf16>
+ return %0 : tensor<1xf16>
+}
+
+// -----
+
+// Capture the alias instead of hard-coding "#map"; the "$" prefix keeps the
+// variable alive across the CHECK-LABEL below.
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
+
+mesh.mesh @mesh_4d(shape = ?x?x?x?)
+
+// Mesh axes [3, 1] are linearized in the listed order (axis 3 major), so the
+// in-group index is idx(axis 3) * size(axis 1) + idx(axis 1).
+// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
+func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+  // CHECK: %[[ARG:.*]]: tensor<?x?xf16>
+  %arg0 : tensor<?x?xf16>
+// CHECK-SAME: -> tensor<?x?xf16> {
+) -> tensor<?x?xf16> {
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[IN_GROUP_PROC_MULTI_IDX:.*]]:2 = mesh.process_multi_index on @mesh_4d axes = [3, 1] : index, index
+  // CHECK-DAG: %[[PROC_GROUP_SHAPE:.*]]:2 = mesh.mesh_shape @mesh_4d axes = [3, 1] : index, index
+  // CHECK: %[[PROC_GROUP_SIZE:.*]] = arith.muli %[[PROC_GROUP_SHAPE]]#0, %[[PROC_GROUP_SHAPE]]#1 : index
+  // CHECK: %[[SCATTER_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf16>
+  // CHECK: %[[AXIS_SIZE_CHECK_REMAINDER:.*]] = arith.remui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[AXIS_SIZE_CHECK_REMAINDER]], %[[C0]] : index
+  // CHECK: cf.assert %[[AXIS_SIZE_CHECK]]
+  // CHECK: %[[RESULT_SCATTER_AXIS_SIZE:.*]] = arith.divui %[[SCATTER_AXIS_SIZE]], %[[PROC_GROUP_SIZE]] : index
+  // CHECK: %[[PROC_IN_GROUP_LINEAR_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IN_GROUP_PROC_MULTI_IDX]]#0, %[[PROC_GROUP_SHAPE]]#1, %[[IN_GROUP_PROC_MULTI_IDX]]#1]
+  // CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
+  // CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
+  // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
+  %0 = mesh.all_scatter %arg0 on @mesh_4d mesh_axes = [3, 1] scatter_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+  // CHECK: return %[[RESULT]] : tensor<?x?xf16>
+  return %0 : tensor<?x?xf16>
+}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 23c5b253b4c073..4fa9c81f4cd004 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,6 +63,19 @@ func.func @all_gather_empty_mesh_axes(
return %0 : tensor<4xf32>
}
+// An all_scatter over an empty list of mesh axes is expected to be removed by
+// canonicalization, with its operand forwarded to the users.
+// CHECK-LABEL: func @all_scatter_empty_mesh_axes
+func.func @all_scatter_empty_mesh_axes(
+// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
+    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// The op name must be spelled in full: "mesh.scatter" never matches
+// "mesh.all_scatter", which made the negative check vacuous.
+// CHECK-NOT: mesh.all_scatter
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    mesh_axes = []
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<4xf32>
+// CHECK: return %[[ARG]]
+  return %0 : tensor<4xf32>
+}
+
// CHECK-LABEL: func @broadcast_empty_mesh_axes
func.func @broadcast_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3fa3ebd67b15e7..590d548d9c769e 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -316,6 +316,58 @@ func.func @all_gather_invalid_negative_gather_axis(
// -----
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_duplicate_mesh_axis(
+    %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0, 0]
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_dynamic_dimension(
+    %arg0 : tensor<?xf32>) -> tensor<2xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0
+    scatter_axis = 0
+    : tensor<?xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_static_dimension_size(
+    %arg0 : tensor<3xf32>) -> tensor<2xf32> {
+  // expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<3xf32> -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+mesh.mesh @mesh0(shape = 3)
+
+func.func @all_scatter_invalid_operand_static_dimension_size(
+    %arg0 : tensor<4xf32>) -> tensor<?xf32> {
+  // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+  %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
+    scatter_axis = 0
+    : tensor<4xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
func.func @all_to_all_invalid_mesh_symbol(
%arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> {
// expected-error at +1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 40a8469b264643..69e8545f3cb86f 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -208,6 +208,30 @@ func.func @all_gather_dynamic_dims_in_mesh(
return %0 : tensor<5x?xf32>
}
+// Fully static shapes: tensor axis 1 shrinks from 4 to 1 when sliced over
+// mesh axis 2 of @mesh0, presumably a size-4 axis in the @mesh0 declaration
+// elsewhere in this file.
+// CHECK-LABEL: func @all_scatter_static_dimensions
+func.func @all_scatter_static_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
+ %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
+ // CHECK-NEXT: mesh.all_scatter %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [2] scatter_axis = 1
+ // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
+ %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [2] scatter_axis = 1
+ : tensor<3x4xf32> -> tensor<3x1xf32>
+ return %0 : tensor<3x1xf32>
+}
+
+// Dynamic operand sliced over two mesh axes ([0, 1]) of @mesh3, which is
+// declared elsewhere in this file. With a dynamic operand axis, the sliced
+// result axis stays dynamic.
+// CHECK-LABEL: func @all_scatter_dynamic_dimensions
+func.func @all_scatter_dynamic_dimensions(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
+ %arg0 : tensor<?xf32>) -> tensor<?xf32> {
+ // CHECK-NEXT: mesh.all_scatter %[[ARG]]
+ // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
+ %0 = mesh.all_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ : tensor<?xf32> -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
// CHECK-LABEL: func @all_to_all
func.func @all_to_all(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index f1d686135c28eb..fdf9905a62e99c 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -19,17 +19,9 @@ func.func @split_replicated_tensor_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
%arg0: tensor<3x14xf32>
) -> tensor<3x14xf32> {
- // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
- // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
- // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
- // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
- // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
- // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
- // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+ // CHECK: %[[ALL_SCATTER:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 1
+ // CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SCATTER]] : tensor<3x7xf32> to tensor<3x14xf32>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
%1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
// CHECK: return %[[RESULT]] : tensor<3x14xf32>
@@ -41,22 +33,11 @@ func.func @split_replicated_tensor_axis_dynamic(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
%arg0: tensor<?x3x?xf32>
) -> tensor<?x3x?xf32> {
- // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
- // CHECK-DAG: %[[TWO:.*]] = arith.constant 2 : index
- // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_multi_index on @mesh_1d_dynamic axes = [0] : index
- // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.mesh_shape @mesh_1d_dynamic axes = [0] : index
- // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
- // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
- // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
- // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
- // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
- // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
- // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
- // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+ // CHECK: %[[RESULT:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] scatter_axis = 0
+ // CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
%0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
%1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
- // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+ // CHECK: return %[[RESULT]] : tensor<?x3x?xf32>
return %1 : tensor<?x3x?xf32>
}
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 9993c1518e9eae..3aeea169107d66 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -63,9 +63,8 @@ func.func @unary_elementwise_with_resharding(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
- // We don't care about the whole resharding IR, just that it happens.
- // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
- // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ // CHECK: %[[SLICE:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS:.*]] = tosa.abs %[[SLICE]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -109,9 +108,8 @@ func.func @multiple_chained_ops(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- // We don't care about the whole resharding IR, just that it happens.
- // CHECK: %[[RESHARD1:.*]] = tensor.extract_slice %[[ARG]][%{{.*}}] [1] [1]
- // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ // CHECK: %[[RESHARD1:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS1:.*]] = tosa.abs %[[RESHARD1]] : (tensor<1xi8>) -> tensor<1xi8>
@@ -122,8 +120,8 @@ func.func @multiple_chained_ops(
%4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
%5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD3:.*]] = tensor.extract_slice %[[ABS2]][%{{.*}}] [1] [1]
- // CHECK-SAME: tensor<2xi8> to tensor<1xi8>
+ // CHECK: %[[RESHARD3:.*]] = mesh.all_scatter %[[ABS2]] on @mesh_1d mesh_axes = [0] scatter_axis = 0 :
+ // CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
%7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 00931e6c94fc5f..07e9bb6f9f238a 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,6 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTest
- TestProcessMultiIndexOpLowering.cpp
+ TestOpLowering.cpp
TestReshardingSpmdization.cpp
TestSimplifications.cpp
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
new file mode 100644
index 00000000000000..2c9131e666ed69
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
@@ -0,0 +1,79 @@
+//===- TestOpLowering.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that greedily applies the patterns registered by
+/// mesh::allScatterOpLoweringPopulatePatterns to the operation it runs on.
+struct TestAllScatterOpLoweringPass
+    : public PassWrapper<TestAllScatterOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllScatterOpLoweringPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    // Handed to the populate function by reference; it lives until the greedy
+    // rewrite below has finished.
+    SymbolTableCollection symbolTableCollection;
+    mesh::allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+    // Report a failed greedy rewrite as a pass failure rather than through an
+    // assert that is compiled out in release builds.
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::allScatterOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-all-scatter-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of all-scatter.";
+  }
+};
+
+/// Test pass that greedily applies the patterns registered by
+/// mesh::processMultiIndexOpLoweringPopulatePatterns to the operation it runs
+/// on.
+struct TestMultiIndexOpLoweringPass
+    : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    // Handed to the populate function by reference; it lives until the greedy
+    // rewrite below has finished.
+    SymbolTableCollection symbolTableCollection;
+    mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+                                                      symbolTableCollection);
+    // Report a failed greedy rewrite as a pass failure rather than through an
+    // assert that is compiled out in release builds.
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns))))
+      signalPassFailure();
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+  }
+  StringRef getArgument() const final {
+    return "test-mesh-process-multi-index-op-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of mesh.process_multi_index op.";
+  }
+};
+
+} // namespace
+
+namespace mlir::test {
+/// Registers both mesh op-lowering test passes through PassRegistration:
+/// `test-mesh-all-scatter-op-lowering` and
+/// `test-mesh-process-multi-index-op-lowering`.
+void registerTestOpLoweringPasses() {
+  PassRegistration<TestAllScatterOpLoweringPass>();
+  PassRegistration<TestMultiIndexOpLoweringPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
deleted file mode 100644
index 0bcc403a2734ed..00000000000000
--- a/mlir/test/lib/Dialect/Mesh/TestProcessMultiIndexOpLowering.cpp
+++ /dev/null
@@ -1,54 +0,0 @@
-//===- TestProcessMultiIndexOpLowering.cpp --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-using namespace mlir;
-
-namespace {
-struct TestMultiIndexOpLoweringPass
- : public PassWrapper<TestMultiIndexOpLoweringPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiIndexOpLoweringPass)
-
- void runOnOperation() override;
- void getDependentDialects(DialectRegistry &registry) const override {
- mesh::processMultiIndexOpLoweringRegisterDialects(registry);
- }
- StringRef getArgument() const final {
- return "test-mesh-process-multi-index-op-lowering";
- }
- StringRef getDescription() const final {
- return "Test lowering of mesh.process_multi_index op.";
- }
-};
-} // namespace
-
-void TestMultiIndexOpLoweringPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- SymbolTableCollection symbolTableCollection;
- mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
- symbolTableCollection);
- LogicalResult status =
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- (void)status;
- assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
-}
-
-namespace mlir {
-namespace test {
-void registerTestMultiIndexOpLoweringPass() {
- PassRegistration<TestMultiIndexOpLoweringPass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 428bdd9691e095..a94f42927f3f3b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,7 +120,7 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshSimplificationsPass();
void registerTestMeshReshardingSpmdizationPass();
-void registerTestMultiIndexOpLoweringPass();
+void registerTestOpLoweringPasses();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -241,7 +241,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
- mlir::test::registerTestMultiIndexOpLoweringPass();
+ mlir::test::registerTestOpLoweringPasses();
mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9f2c1f1120066b..8943e91c1f0ceb 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3314,7 +3314,9 @@ cc_library(
includes = ["include"],
deps = [
":AffineDialect",
+ ":AffineUtils",
":ArithDialect",
+ ":ArithUtils",
":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
>From 8cbbb177dc5ba8803e4b068672dc679ef230fe80 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Thu, 8 Feb 2024 19:10:01 -0800
Subject: [PATCH 2/3] Remove unused resharding dependencies
---
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 21e03d9572c590..cd0481d7b87aec 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -8,9 +8,6 @@
#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
@@ -519,8 +516,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
}
void reshardingRegisterDependentDialects(DialectRegistry ®istry) {
- registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
- cf::ControlFlowDialect>();
+ registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}
#define GEN_PASS_DEF_SPMDIZATION
>From 73cffb0c1cb1d9a4e02420ff12c3ee9cbcd251ba Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Tue, 13 Feb 2024 13:33:56 -0800
Subject: [PATCH 3/3] Address PR comments
---
mlir/include/mlir/Dialect/Affine/Utils.h | 10 +--
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 18 ++---
.../mlir/Dialect/Mesh/Transforms/Transforms.h | 8 +-
mlir/lib/Dialect/Affine/Utils/Utils.cpp | 21 +++---
mlir/lib/Dialect/Arith/Utils/Utils.cpp | 2 +-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 45 ++++++------
.../Dialect/Mesh/Transforms/Spmdization.cpp | 6 +-
.../Dialect/Mesh/Transforms/Transforms.cpp | 73 +++++++++++--------
.../Dialect/Mesh/all-scatter-op-lowering.mlir | 20 ++---
mlir/test/Dialect/Mesh/canonicalization.mlir | 8 +-
mlir/test/Dialect/Mesh/invalid.mlir | 30 ++++----
mlir/test/Dialect/Mesh/ops.mlir | 20 ++---
.../Dialect/Mesh/resharding-spmdization.mlir | 6 +-
mlir/test/Dialect/Mesh/spmdization.mlir | 6 +-
mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp | 20 ++---
15 files changed, 152 insertions(+), 141 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 441ce1952210bb..9f37528c27046e 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,11 +311,11 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
Value linearIndex,
ArrayRef<Value> basis);
-// Generate IR that extracts the linear index form a multi-index according to
-// a shape.
-OpFoldResult linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
- ArrayRef<OpFoldResult> shape,
- ImplicitLocOpBuilder &builder);
+/// Generate IR that computes the linear index of `multiIndex` in a row-major
+/// (last dimension fastest) layout described by `basis` (typically a shape),
+/// i.e. the sum of multiIndex[i] * stride[i] with strides derived from
+/// `basis`. `multiIndex` and `basis` must have the same number of elements.
+/// The result is an OpFoldResult, so it may be an attribute rather than a
+/// Value when it folds. New IR uses `builder`'s insertion point and location.
+OpFoldResult linearIndex(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> basis,
+ ImplicitLocOpBuilder &builder);
/// Ensure that all operations that could be executed after `start`
/// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 7792aac784d4be..96636d5347ff6e 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -342,15 +342,15 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
let hasCanonicalizer = 1;
}
-def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
+def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
Pure,
SameOperandsAndResultElementType,
SameOperandsAndResultRank
]> {
- let summary = "All-scatter over a device mesh. This is the inverse of all-gather.";
+ let summary = "All-slice over a device mesh. This is the inverse of all-gather.";
let description = [{
- Scatter along the `scatter_axis` tensor axis.
- This operation can be taught of as the inverse of all-gather.
+ Slice along the `slice_axis` tensor axis.
+ This operation can be thought of as the inverse of all-gather.
Technically, it is not required that all processes have the same input tensor.
Each process will slice a piece of its local tensor based on its in-group device index.
The operation does not communicate data between devices.
@@ -359,7 +359,7 @@ def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
```mlir
mesh.mesh @mesh0(shape = 2x2)
...
- %1 = mesh.all_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
+ %1 = mesh.all_slice %0 on @mesh0 mesh_axes = [1] slice_axis = 1
: tensor<2x4xi8> -> tensor<2x2xi8>
```
Input:
@@ -388,19 +388,19 @@ def Mesh_AllScatterOp : Mesh_CollectiveCommunicationOpBase<"all_scatter", [
}];
let arguments = !con(commonArgs, (ins
AnyNon0RankedTensor:$input,
- IndexAttr:$scatter_axis
+ IndexAttr:$slice_axis
));
let results = (outs
AnyNon0RankedTensor:$result
);
let assemblyFormat = [{
- $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `scatter_axis` `=` $scatter_axis
+ $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `slice_axis` `=` $slice_axis
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
let builders = [
- OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>,
- OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$scatterAxis)>
+ OpBuilder<(ins "Value":$input, "MeshOp":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>,
+ OpBuilder<(ins "Type":$result_type, "Value":$input, "StringRef":$mesh, "ArrayRef<MeshAxis>":$meshAxes, "int64_t":$sliceAxis)>
];
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index d398bdd65330b1..aeab28961a4e1e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -21,13 +21,13 @@ class DialectRegistry;
class ImplicitLocOpBuilder;
namespace mesh {
-void processMultiIndexOpLoweringPopulatePatterns(
+void populateProcessMultiIndexOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry);
+void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry);
-void allScatterOpLoweringPopulatePatterns(
+void populateAllSliceOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
-void allScatterOpLoweringRegisterDialects(DialectRegistry &registry);
+void registerAllSliceOpLoweringDialects(DialectRegistry &registry);
void populateAllOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 41bcd2b9f33e6b..3afecd747f401c 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1872,23 +1872,22 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
return results;
}
-OpFoldResult
-mlir::affine::linearIndexFromShape(ArrayRef<OpFoldResult> multiIndex,
- ArrayRef<OpFoldResult> shape,
- ImplicitLocOpBuilder &builder) {
- assert(multiIndex.size() == shape.size());
- SmallVector<AffineExpr> shapeAffine;
- for (size_t i = 0; i < shape.size(); ++i) {
- shapeAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
+OpFoldResult mlir::affine::linearIndex(ArrayRef<OpFoldResult> multiIndex,
+ ArrayRef<OpFoldResult> basis,
+ ImplicitLocOpBuilder &builder) {
+ assert(multiIndex.size() == basis.size());
+ SmallVector<AffineExpr> basisAffine;
+ for (size_t i = 0; i < basis.size(); ++i) {
+ basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
}
- SmallVector<AffineExpr> stridesAffine = computeStrides(shapeAffine);
+ SmallVector<AffineExpr> stridesAffine = computeStrides(basisAffine);
SmallVector<OpFoldResult> strides;
strides.reserve(stridesAffine.size());
llvm::transform(stridesAffine, std::back_inserter(strides),
- [&builder, &shape](AffineExpr strideExpr) {
+ [&builder, &basis](AffineExpr strideExpr) {
return affine::makeComposedFoldedAffineApply(
- builder, builder.getLoc(), strideExpr, shape);
+ builder, builder.getLoc(), strideExpr, basis);
});
auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index 999cdbc5c10008..aa239f5e053969 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -280,4 +280,4 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
[&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
}
-} // namespace mlir::arith
\ No newline at end of file
+} // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 762725d2c56e66..91165765245c6d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -558,7 +558,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
return success();
}
-static LogicalResult verifyScatterOperandAndResultShape(
+static LogicalResult verifyScatterOrSliceOperandAndResultShape(
Value operand, Value result, int64_t scatterAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
@@ -582,7 +582,7 @@ static LogicalResult verifyScatterOperandAndResultShape(
return emitError(result.getLoc())
<< "Operand dimension size " << int64_t(operandScatterDimSize)
<< " is not divisible by collective device group size "
- << int64_t(deviceGroupSize) << " for scatter axis " << scatterAxis
+ << int64_t(deviceGroupSize) << " for tensor axis " << scatterAxis
<< ".";
}
DimensionSize expectedResultScatterDimSize =
@@ -648,39 +648,38 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
}
//===----------------------------------------------------------------------===//
-// mesh.all_scatter op
+// mesh.all_slice op
//===----------------------------------------------------------------------===//
-LogicalResult
-AllScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
if (failed(mesh)) {
return failure();
}
- return verifyScatterOperandAndResultShape(
- getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
+ return verifyScatterOrSliceOperandAndResultShape(
+ getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(),
mesh.value().getShape());
}
-void AllScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
- MLIRContext *context) {
- patterns.add<EmptyMeshAxesCanonicalizationPattern<AllScatterOp>>(context);
+void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context);
}
-void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
- int64_t scatterAxis) {
+void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
+ int64_t sliceAxis) {
Type resultType =
- scatterResultType(input.getType(), mesh, meshAxes, scatterAxis);
+ scatterResultType(input.getType(), mesh, meshAxes, sliceAxis);
build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
- scatterAxis);
+ sliceAxis);
}
-void AllScatterOp::build(OpBuilder &odsBuilder, OperationState &odsState,
- Type resultType, Value input, StringRef mesh,
- ArrayRef<MeshAxis> meshAxes, int64_t scatterAxis) {
+void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Type resultType, Value input, StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) {
build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
- APInt(sizeof(scatterAxis) * CHAR_BIT, scatterAxis));
+ APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis));
}
//===----------------------------------------------------------------------===//
@@ -810,7 +809,7 @@ ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return failure();
}
- return verifyScatterOperandAndResultShape(
+ return verifyScatterOrSliceOperandAndResultShape(
getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(),
mesh.value().getShape());
}
@@ -836,9 +835,9 @@ LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
auto scatterAxis = getScatterAxis().getSExtValue();
- return verifyScatterOperandAndResultShape(getInput(), getResult(),
- scatterAxis, getMeshAxes(),
- mesh.value().getShape());
+ return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(),
+ scatterAxis, getMeshAxes(),
+ mesh.value().getShape());
}
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index cd0481d7b87aec..7cbe0de048769b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -135,9 +135,9 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
TypedValue<ShapedType> targetShard =
builder
- .create<AllScatterOp>(sourceShard, mesh,
- ArrayRef<MeshAxis>(splitMeshAxis),
- splitTensorAxis)
+ .create<AllSliceOp>(sourceShard, mesh,
+ ArrayRef<MeshAxis>(splitMeshAxis),
+ splitTensorAxis)
.getResult()
.cast<TypedValue<ShapedType>>();
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index 95739c18ce0f36..c8c75b94a42206 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -51,7 +51,6 @@ struct ProcessMultiIndexOpLowering
builder.setInsertionPointAfter(op.getOperation());
Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
- llvm::errs() << "meshShape.size() = " << meshShape.size() << "\n";
SmallVector<Value> completeMultiIndex =
builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
.getMultiIndex();
@@ -72,13 +71,28 @@ struct ProcessMultiIndexOpLowering
}
};
-struct AllScatterOpLowering
- : OpRewritePatternWithSymbolTableCollection<AllScatterOp> {
+struct AllSliceOpLowering
+ : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
using OpRewritePatternWithSymbolTableCollection::
OpRewritePatternWithSymbolTableCollection;
- LogicalResult matchAndRewrite(AllScatterOp op,
+ LogicalResult matchAndRewrite(AllSliceOp op,
PatternRewriter &rewriter) const override {
+ // 1. Compute the process linear index inside the process group from its
+ // multi-index.
+ //
+ // 2. Extract a slice from the input tensor.
+ // All axes except the slicing axis are not interesting and take the full
+ // axis.
+ // The slice axis is split into equisized parts with count
+ // the number of processes in the collective process group induced by
+ // the mesh axes.
+ // The part for each process is determined by the corresponding
+ // linear-index in the process group.
+ //
+ // There are no collectives that require communication.
+ // Each process operates on its local tensor.
+
MeshOp mesh = getMesh(op, symbolTableCollection);
if (!mesh) {
return failure();
@@ -99,32 +113,31 @@ struct AllScatterOpLowering
Value processGroupSize =
createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
- int64_t scatterAxis = op.getScatterAxis().getSExtValue();
- Value operandScatterAxisSize =
- builder.create<tensor::DimOp>(op.getOperand(), scatterAxis);
- Value operandScatterAxisSizeModProcessGroupSize =
- builder.create<arith::RemUIOp>(operandScatterAxisSize,
- processGroupSize);
+ int64_t sliceAxis = op.getSliceAxis().getSExtValue();
+ Value operandSliceAxisSize =
+ builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
+ Value operandSliceAxisSizeModProcessGroupSize =
+ builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
- arith::CmpIPredicate::eq, operandScatterAxisSizeModProcessGroupSize,
+ arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
zero);
builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
- "Scattering a tensor with axis size that is "
+ "Slicing a tensor with axis size that is "
"not exactly divisible by the "
"mesh process group size is not supported.");
- Value resultScatterAxisSize = builder.create<arith::DivUIOp>(
- operandScatterAxisSize, processGroupSize);
- OpFoldResult processInGroupLinearIndex = affine::linearIndexFromShape(
+ Value resultSliceAxisSize =
+ builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
+ OpFoldResult processInGroupLinearIndex = affine::linearIndex(
llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
- // extract slice
+ // insert tensor.extract_slice
RankedTensorType operandType =
op.getOperand().getType().cast<RankedTensorType>();
SmallVector<OpFoldResult> sizes;
for (int64_t i = 0; i < operandType.getRank(); ++i) {
- if (i == scatterAxis) {
- sizes.emplace_back(resultScatterAxisSize);
+ if (i == sliceAxis) {
+ sizes.emplace_back(resultSliceAxisSize);
} else {
Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
sizes.emplace_back(dimSize);
@@ -132,11 +145,11 @@ struct AllScatterOpLowering
}
SmallVector<OpFoldResult> offsets(
operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
- offsets[scatterAxis] =
+ offsets[sliceAxis] =
ArithBuilder(builder, builder.getLoc())
.mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
processInGroupLinearIndex),
- resultScatterAxisSize);
+ resultSliceAxisSize);
SmallVector<OpFoldResult> strides(
operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
Value slice = builder.create<tensor::ExtractSliceOp>(
@@ -151,23 +164,23 @@ struct AllScatterOpLowering
} // namespace
-void processMultiIndexOpLoweringPopulatePatterns(
+void populateProcessMultiIndexOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
patterns.getContext());
}
-void processMultiIndexOpLoweringRegisterDialects(DialectRegistry &registry) {
+void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, mesh::MeshDialect>();
}
-void allScatterOpLoweringPopulatePatterns(
+void populateAllSliceOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
- patterns.add<AllScatterOpLowering>(symbolTableCollection,
- patterns.getContext());
+ patterns.add<AllSliceOpLowering>(symbolTableCollection,
+ patterns.getContext());
}
-void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
+void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
registry.insert<affine::AffineDialect, arith::ArithDialect,
cf::ControlFlowDialect, mesh::MeshDialect,
tensor::TensorDialect>();
@@ -175,13 +188,13 @@ void allScatterOpLoweringRegisterDialects(DialectRegistry &registry) {
void populateAllOpLoweringPatterns(
RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
- processMultiIndexOpLoweringPopulatePatterns(patterns, symbolTableCollection);
- allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+ populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
+ populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
}
void registerAllOpLoweringDialects(DialectRegistry &registry) {
- processMultiIndexOpLoweringRegisterDialects(registry);
- allScatterOpLoweringRegisterDialects(registry);
+ registerProcessMultiIndexOpLoweringDialects(registry);
+ registerAllSliceOpLoweringDialects(registry);
}
TypedValue<IndexType>
diff --git a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
index 8aa3fc5160a001..4f54607a1c7ff0 100644
--- a/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
+++ b/mlir/test/Dialect/Mesh/all-scatter-op-lowering.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt --split-input-file --test-mesh-all-scatter-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --test-mesh-all-slice-op-lowering --test-mesh-simplifications --cse %s | FileCheck %s
mesh.mesh @mesh_1d(shape = ?)
-// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
-func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh
+func.func @all_slice_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
// CHECK: %[[ARG:.*]]: tensor<?xf16>
%arg0: tensor<?xf16>
// CHECK-SAME: -> tensor<?xf16> {
@@ -18,7 +18,7 @@ func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
// CHECK: %[[RESULT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_AXIS_SIZE]], %[[MESH_SIZE]] : index
// CHECK: %[[SLICE_OFFSET:.*]] = arith.muli %[[PROC_IDX]], %[[RESULT_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][%[[SLICE_OFFSET]]] [%[[RESULT_AXIS_SIZE]]] [1] : tensor<?xf16> to tensor<?xf16>
- %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<?xf16> -> tensor<?xf16>
+ %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<?xf16> -> tensor<?xf16>
// CHECK: return %[[RESULT]] : tensor<?xf16>
return %0 : tensor<?xf16>
}
@@ -27,8 +27,8 @@ func.func @all_scatter_op_lowering_of_dynamic_1d_tensor_on_dynamic_1d_mesh(
mesh.mesh @mesh_1d(shape = 2)
-// CHECK-LABEL: func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh
-func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh
+func.func @all_slice_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
// CHECK: %[[ARG:.*]]: tensor<2xf16>
%arg0: tensor<2xf16>
// CHECK-SAME: -> tensor<1xf16> {
@@ -37,7 +37,7 @@ func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
// CHECK: %[[PROC_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[PROC_IDX]]] [%[[C1]]] [1] : tensor<2xf16> to tensor<?xf16>
// CHECK: %[[RESULT:.*]] = tensor.cast %[[SLICE]] : tensor<?xf16> to tensor<1xf16>
- %0 = mesh.all_scatter %arg0 on @mesh_1d mesh_axes = [0] scatter_axis = 0 : tensor<2xf16> -> tensor<1xf16>
+ %0 = mesh.all_slice %arg0 on @mesh_1d mesh_axes = [0] slice_axis = 0 : tensor<2xf16> -> tensor<1xf16>
// CHECK: return %[[RESULT]] : tensor<1xf16>
return %0 : tensor<1xf16>
}
@@ -48,8 +48,8 @@ func.func @all_scatter_op_lowering_of_static_1d_tensor_on_static_1d_mesh(
mesh.mesh @mesh_4d(shape = ?x?x?x?)
-// CHECK-LABEL: func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
-func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
+// CHECK-LABEL: func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh
+func.func @all_slice_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
// CHECK: %[[ARG:.*]]: tensor<?x?xf16>
%arg0 : tensor<?x?xf16>
// CHECK-SAME: -> tensor<?x?xf16> {
@@ -68,7 +68,7 @@ func.func @all_scatter_op_lowering_of_dynamic_2d_tensor_on_dynamic_4d_mesh(
// CHECK: %[[AXIS_0_SIZE:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf16>
// CHECK: %[[SCATTER_AXIS_OFFSET:.*]] = arith.muli %[[PROC_IN_GROUP_LINEAR_IDX]], %[[RESULT_SCATTER_AXIS_SIZE]] : index
// CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[ARG]][0, %[[SCATTER_AXIS_OFFSET]]] [%[[AXIS_0_SIZE]], %[[RESULT_SCATTER_AXIS_SIZE]]] [1, 1] : tensor<?x?xf16> to tensor<?x?xf16>
- %0 = mesh.all_scatter %arg0 on @mesh_4d mesh_axes = [3, 1] scatter_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
+ %0 = mesh.all_slice %arg0 on @mesh_4d mesh_axes = [3, 1] slice_axis = 1 : tensor<?x?xf16> -> tensor<?x?xf16>
// CHECK: return %[[RESULT]] : tensor<?x?xf16>
return %0 : tensor<?x?xf16>
}
diff --git a/mlir/test/Dialect/Mesh/canonicalization.mlir b/mlir/test/Dialect/Mesh/canonicalization.mlir
index 4fa9c81f4cd004..633324ae680eb1 100644
--- a/mlir/test/Dialect/Mesh/canonicalization.mlir
+++ b/mlir/test/Dialect/Mesh/canonicalization.mlir
@@ -63,14 +63,14 @@ func.func @all_gather_empty_mesh_axes(
return %0 : tensor<4xf32>
}
-// CHECK-LABEL: func @all_scatter_empty_mesh_axes
-func.func @all_scatter_empty_mesh_axes(
+// CHECK-LABEL: func @all_slice_empty_mesh_axes
+func.func @all_slice_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.scatter
- %0 = mesh.all_scatter %arg0 on @mesh0
+ %0 = mesh.all_slice %arg0 on @mesh0
mesh_axes = []
- scatter_axis = 0
+ slice_axis = 0
: tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
return %0 : tensor<4xf32>
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 590d548d9c769e..6d7df86d78406f 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -318,11 +318,11 @@ func.func @all_gather_invalid_negative_gather_axis(
mesh.mesh @mesh0(shape = 3)
-func.func @all_scatter_duplicate_mesh_axis(
+func.func @all_slice_duplicate_mesh_axis(
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
// expected-error@+1 {{Mesh axes contains duplicate elements.}}
- %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0, 0]
- scatter_axis = 0
+ %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0, 0]
+ slice_axis = 0
: tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -331,11 +331,11 @@ func.func @all_scatter_duplicate_mesh_axis(
mesh.mesh @mesh0(shape = 3)
-func.func @all_scatter_invalid_dynamic_dimension(
+func.func @all_slice_invalid_dynamic_dimension(
%arg0 : tensor<?xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected dynamic, but got 2.}}
- %0 = mesh.all_scatter %arg0 on @mesh0
- scatter_axis = 0
+ %0 = mesh.all_slice %arg0 on @mesh0
+ slice_axis = 0
: tensor<?xf32> -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -344,11 +344,11 @@ func.func @all_scatter_invalid_dynamic_dimension(
mesh.mesh @mesh0(shape = 3)
-func.func @all_scatter_invalid_static_dimension_size(
+func.func @all_slice_invalid_static_dimension_size(
%arg0 : tensor<3xf32>) -> tensor<2xf32> {
// expected-error@+1 {{Dimension size mismatch for result axis 0. Expected 1, but got 2.}}
- %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
- scatter_axis = 0
+ %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+ slice_axis = 0
: tensor<3xf32> -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@@ -357,11 +357,11 @@ func.func @all_scatter_invalid_static_dimension_size(
mesh.mesh @mesh0(shape = 3)
-func.func @all_scatter_invalid_operand_static_dimension_size(
+func.func @all_slice_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
- // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
- %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [0]
- scatter_axis = 0
+ // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
+ %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [0]
+ slice_axis = 0
: tensor<4xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
@@ -712,7 +712,7 @@ mesh.mesh @mesh0(shape = 3)
func.func @reduce_scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf64> {
- // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+ // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
%0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [0] scatter_axis = 0
: tensor<4xf32> -> tensor<?xf64>
return %0 : tensor<?xf64>
@@ -763,7 +763,7 @@ mesh.mesh @mesh0(shape = 3)
func.func @scatter_invalid_operand_static_dimension_size(
%arg0 : tensor<4xf32>) -> tensor<?xf32> {
- // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for scatter axis 0.}}
+ // expected-error@+1 {{Operand dimension size 4 is not divisible by collective device group size 3 for tensor axis 0.}}
%0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0]
scatter_axis = 0 root = [1]
: (tensor<4xf32>) -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 69e8545f3cb86f..6e5df86b13106a 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -208,26 +208,26 @@ func.func @all_gather_dynamic_dims_in_mesh(
return %0 : tensor<5x?xf32>
}
-// CHECK-LABEL: func @all_scatter_static_dimensions
-func.func @all_scatter_static_dimensions(
+// CHECK-LABEL: func @all_slice_static_dimensions
+func.func @all_slice_static_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
%arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> {
- // CHECK-NEXT: mesh.all_scatter %[[ARG]]
- // CHECK-SAME: on @mesh0 mesh_axes = [2] scatter_axis = 1
+ // CHECK-NEXT: mesh.all_slice %[[ARG]]
+ // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
// CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
- %0 = mesh.all_scatter %arg0 on @mesh0 mesh_axes = [2] scatter_axis = 1
+ %0 = mesh.all_slice %arg0 on @mesh0 mesh_axes = [2] slice_axis = 1
: tensor<3x4xf32> -> tensor<3x1xf32>
return %0 : tensor<3x1xf32>
}
-// CHECK-LABEL: func @all_scatter_dynamic_dimensions
-func.func @all_scatter_dynamic_dimensions(
+// CHECK-LABEL: func @all_slice_dynamic_dimensions
+func.func @all_slice_dynamic_dimensions(
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%arg0 : tensor<?xf32>) -> tensor<?xf32> {
- // CHECK-NEXT: mesh.all_scatter %[[ARG]]
- // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ // CHECK-NEXT: mesh.all_slice %[[ARG]]
+ // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
// CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
- %0 = mesh.all_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
+ %0 = mesh.all_slice %arg0 on @mesh3 mesh_axes = [0, 1] slice_axis = 0
: tensor<?xf32> -> tensor<?xf32>
return %0 : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
index fdf9905a62e99c..ba05306598bcc6 100644
--- a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -19,9 +19,9 @@ func.func @split_replicated_tensor_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
%arg0: tensor<3x14xf32>
) -> tensor<3x14xf32> {
- // CHECK: %[[ALL_SCATTER:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 1
+ // CHECK: %[[ALL_SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 1
// CHECK-SAME: tensor<3x14xf32> -> tensor<3x7xf32>
- // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SCATTER]] : tensor<3x7xf32> to tensor<3x14xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ALL_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
%1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
// CHECK: return %[[RESULT]] : tensor<3x14xf32>
@@ -33,7 +33,7 @@ func.func @split_replicated_tensor_axis_dynamic(
// CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
%arg0: tensor<?x3x?xf32>
) -> tensor<?x3x?xf32> {
- // CHECK: %[[RESULT:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] scatter_axis = 0
+ // CHECK: %[[RESULT:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d_dynamic mesh_axes = [0] slice_axis = 0
// CHECK-SAME: tensor<?x3x?xf32> -> tensor<?x3x?xf32>
%0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
%1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 3aeea169107d66..2fb8029dfe64ae 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -63,7 +63,7 @@ func.func @unary_elementwise_with_resharding(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<2xi8> {
) -> tensor<2xi8> {
- // CHECK: %[[SLICE:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+ // CHECK: %[[SLICE:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
@@ -108,7 +108,7 @@ func.func @multiple_chained_ops(
%arg0: tensor<2xi8>
// CHECK-SAME: -> tensor<1xi8> {
) -> tensor<2xi8> {
- // CHECK: %[[RESHARD1:.*]] = mesh.all_scatter %[[ARG]] on @mesh_1d mesh_axes = [0] scatter_axis = 0
+ // CHECK: %[[RESHARD1:.*]] = mesh.all_slice %[[ARG]] on @mesh_1d mesh_axes = [0] slice_axis = 0
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xi8>
%1 = mesh.shard %0 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
@@ -120,7 +120,7 @@ func.func @multiple_chained_ops(
%4 = mesh.shard %3 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8>
// CHECK: %[[ABS2:.*]] = tosa.abs %[[RESHARD2]] : (tensor<2xi8>) -> tensor<2xi8>
%5 = tosa.abs %4 : (tensor<2xi8>) -> tensor<2xi8>
- // CHECK: %[[RESHARD3:.*]] = mesh.all_scatter %[[ABS2]] on @mesh_1d mesh_axes = [0] scatter_axis = 0 :
+ // CHECK: %[[RESHARD3:.*]] = mesh.all_slice %[[ABS2]] on @mesh_1d mesh_axes = [0] slice_axis = 0 :
// CHECK-SAME: tensor<2xi8> -> tensor<1xi8>
%6 = mesh.shard %5 to <@mesh_1d, [[]]> : tensor<2xi8>
%7 = mesh.shard %6 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
index 2c9131e666ed69..321b6a42bf9667 100644
--- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
+++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp
@@ -18,27 +18,27 @@ using namespace mlir;
namespace {
-struct TestAllScatterOpLoweringPass
- : public PassWrapper<TestAllScatterOpLoweringPass, OperationPass<>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllScatterOpLoweringPass)
+struct TestAllSliceOpLoweringPass
+ : public PassWrapper<TestAllSliceOpLoweringPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllSliceOpLoweringPass)
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::allScatterOpLoweringPopulatePatterns(patterns, symbolTableCollection);
+ mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
}
void getDependentDialects(DialectRegistry &registry) const override {
- mesh::allScatterOpLoweringRegisterDialects(registry);
+ mesh::registerAllSliceOpLoweringDialects(registry);
}
StringRef getArgument() const final {
- return "test-mesh-all-scatter-op-lowering";
+ return "test-mesh-all-slice-op-lowering";
}
StringRef getDescription() const final {
- return "Test lowering of all-scatter.";
+ return "Test lowering of all-slice.";
}
};
@@ -49,7 +49,7 @@ struct TestMultiIndexOpLoweringPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
SymbolTableCollection symbolTableCollection;
- mesh::processMultiIndexOpLoweringPopulatePatterns(patterns,
+ mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
symbolTableCollection);
LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
@@ -57,7 +57,7 @@ struct TestMultiIndexOpLoweringPass
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
}
void getDependentDialects(DialectRegistry &registry) const override {
- mesh::processMultiIndexOpLoweringRegisterDialects(registry);
+ mesh::registerProcessMultiIndexOpLoweringDialects(registry);
}
StringRef getArgument() const final {
return "test-mesh-process-multi-index-op-lowering";
@@ -72,7 +72,7 @@ struct TestMultiIndexOpLoweringPass
namespace mlir {
namespace test {
void registerTestOpLoweringPasses() {
- PassRegistration<TestAllScatterOpLoweringPass>();
+ PassRegistration<TestAllSliceOpLoweringPass>();
PassRegistration<TestMultiIndexOpLoweringPass>();
}
} // namespace test
More information about the Mlir-commits
mailing list