[Mlir-commits] [mlir] 7a4c497 - [mlir][mesh] Use one type for mesh axis (#76830)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 3 15:47:15 PST 2024
Author: Boian Petkantchin
Date: 2024-01-03T15:47:11-08:00
New Revision: 7a4c49756db161ebcce08c7bc860a569aad7f276
URL: https://github.com/llvm/llvm-project/commit/7a4c49756db161ebcce08c7bc860a569aad7f276
DIFF: https://github.com/llvm/llvm-project/commit/7a4c49756db161ebcce08c7bc860a569aad7f276.diff
LOG: [mlir][mesh] Use one type for mesh axis (#76830)
Make all ops and attributes use the types MeshAxis and MeshAxesAttr
instead of int16_t, int32_t, DenseI16ArrayAttr and DenseI32ArrayAttr.
Added:
Modified:
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index a9d30dfbb9a76e..060d54b82efa63 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -80,8 +80,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
- ArrayRefParameter<"::mlir::DenseI32ArrayAttr">:$split_axes,
- OptionalArrayRefParameter<"int32_t">:$partial_axes,
+ ArrayRefParameter<"MeshAxesAttr">:$split_axes,
+ OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
);
@@ -146,18 +146,18 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let builders = [
AttrBuilder<(ins "SymbolRefAttr":$cluster,
- "ArrayRef<SmallVector<int32_t>>":$split_axes,
- "ArrayRef<int32_t>": $partial_axes,
+ "ArrayRef<SmallVector<MeshAxis>>":$split_axes,
+ "ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
- SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
- split_axes, [&](ArrayRef<int32_t> array) {
- return DenseI32ArrayAttr::get($_ctxt, array);
+ SmallVector<MeshAxesAttr> splitAxesAttr = llvm::map_to_vector(
+ split_axes, [&](ArrayRef<MeshAxis> array) {
+ return MeshAxesAttr::get($_ctxt, array);
});
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
partial_type);
}]>,
AttrBuilder<(ins "SymbolRefAttr":$cluster,
- "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+ "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
}]>
];
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index ce7d5d045122d9..83452dcc2e8abe 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -17,6 +17,15 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <algorithm>
+namespace mlir {
+namespace mesh {
+
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
+} // namespace mesh
+} // namespace mlir
+
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.h.inc"
@@ -30,9 +39,6 @@
namespace mlir {
namespace mesh {
-using MeshAxis = int16_t;
-using MeshAxesAttr = DenseI16ArrayAttr;
-
bool isReductionLoop(IteratorType iType);
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1ed54b6519e4d8..1934bdfb427059 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -114,7 +114,7 @@ def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMeth
let builders = [
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
@@ -228,7 +228,7 @@ def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMeth
}];
let builders = [
OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
- OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<MeshAxis>":$axes)>
];
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 270955a3036e89..201c0151754eba 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -18,8 +18,8 @@ class Operation;
namespace mesh {
-using ShardingArray = SmallVector<SmallVector<int32_t>>;
-using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+using ShardingArray = SmallVector<SmallVector<MeshAxis>>;
+using ShardingArrayRef = ArrayRef<SmallVector<MeshAxis>>;
struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index de4f58d54e8ca5..c3d8f1d456106d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,15 +266,15 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
- ArrayRef<int32_t> partialAxes, Partial) {
+ SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+ ArrayRef<MeshAxis> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
- llvm::SmallSet<int32_t, 4> visitedAxes;
+ llvm::SmallSet<MeshAxis, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<int32_t> axesArray) -> LogicalResult {
- for (int32_t axis : axesArray) {
+ auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult {
+ for (MeshAxis axis : axesArray) {
if (axis < 0)
return emitError() << "mesh axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
@@ -283,8 +283,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (DenseI32ArrayAttr subAxes : splitAxes) {
- ArrayRef<int32_t> subAxesArray = subAxes.asArrayRef();
+ for (MeshAxesAttr subAxes : splitAxes) {
+ ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
@@ -318,10 +318,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
getSplitAxes().end()),
- std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+ std::mem_fn(&MeshAxesAttr::empty)) &&
llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
rhs.getSplitAxes().end()),
- std::mem_fn(&DenseI32ArrayAttr::empty));
+ std::mem_fn(&MeshAxesAttr::empty));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index a6f2f435f36d68..ee885ab16b7b06 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -216,7 +216,7 @@ namespace {
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
SymbolRefAttr cluster,
- ArrayRef<int32_t> meshAxes,
+ ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
if ((shardingOption.cluster && cluster &&
shardingOption.cluster != cluster) ||
@@ -230,7 +230,7 @@ static LogicalResult fillShardingOption(Operation *op,
if (i == loopIdx)
continue;
- for (int32_t axis : meshAxes) {
+ for (MeshAxis axis : meshAxes) {
if (llvm::is_contained(shardingOption.shardingArray[i], axis)) {
LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes "
<< axis << " duplicate");
@@ -260,7 +260,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
unsigned numOperands = op->getNumOperands();
shardingOption.shardingArray.resize(loopTypes.size());
- llvm::SmallVector<int32_t> partialMeshAxes;
+ llvm::SmallVector<MeshAxis> partialMeshAxes;
Partial partialType;
llvm::SmallSet<unsigned, 4> visitedLoopIndices;
bool anyShardingInResultsOrOperands = false;
@@ -277,7 +277,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// shardingOption[index]
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
auto dim = cast<AffineDimExpr>(expr);
unsigned index = dim.getPosition();
visitedLoopIndices.insert(index);
@@ -288,7 +288,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// Handle the partial axes: at this stage, the exact loop index/indices
// cannot be decided because there could be multiple reduction loops.
- ArrayRef<int32_t> partialAxes = shardAttr.getPartialAxes();
+ ArrayRef<MeshAxis> partialAxes = shardAttr.getPartialAxes();
if (!partialAxes.empty()) {
if (!partialMeshAxes.empty())
return op->emitOpError() << "at most one result with partial axes is "
@@ -321,7 +321,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
// then the operands with multiple loop indices.
for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
AffineExpr expr = std::get<0>(it);
- ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+ ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef();
FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
checkOperandAffineExpr(expr, numDims);
if (failed(loopIndices))
@@ -362,7 +362,7 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
if (!partialMeshAxes.empty()) {
bool anyNonEmptyReductionLoop = llvm::any_of(
llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
- SmallVector<int32_t> &subArray = it.value();
+ SmallVector<MeshAxis> &subArray = it.value();
int64_t idx = it.index();
return isReductionLoop(loopTypes[idx]) && !subArray.empty();
});
@@ -406,8 +406,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
return success();
auto resultType = result.getType().cast<RankedTensorType>();
- SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
- SmallVector<int32_t> partialAxes;
+ SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
+ SmallVector<MeshAxis> partialAxes;
// process the split axes
for (auto it : llvm::enumerate(map.getResults())) {
@@ -431,7 +431,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
assert(partialType == curPartialType &&
"Only one reduction type is supported");
partialType = curPartialType;
- const SmallVector<int32_t> &axis = std::get<1>(it);
+ const SmallVector<MeshAxis> &axis = std::get<1>(it);
partialAxes.append(axis);
}
}
@@ -459,7 +459,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
return success();
Value operand = opOperand.get();
auto operandType = operand.getType().cast<RankedTensorType>();
- SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+ SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
int64_t idx = it.index();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 8d7e89662131a0..37b86535959652 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -147,7 +147,7 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
.getResult()
.cast<TypedValue<ShapedType>>();
- llvm::SmallVector<int32_t> remainingPartialAxes;
+ llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
std::back_inserter(allReduceMeshAxes),
[&targetShardingPartialAxesSet](Axis a) {
@@ -163,17 +163,17 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
static MeshShardingAttr
targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
- SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
splitTensorAxis) {
- targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
}
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
targetSplitAxes.push_back(splitMeshAxis);
targetShardingSplitAxes[splitTensorAxis] =
- DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -356,7 +356,7 @@ static MeshShardingAttr
targetShardingInUnsplitLastAxis(MLIRContext *ctx,
MeshShardingAttr sourceSharding,
int64_t splitTensorAxis) {
- SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
splitTensorAxis);
@@ -365,7 +365,7 @@ targetShardingInUnsplitLastAxis(MLIRContext *ctx,
targetSplitAxes.pop_back();
targetShardingSplitAxes[splitTensorAxis] =
- DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
@@ -475,11 +475,11 @@ static MeshShardingAttr
targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
int64_t sourceTensorAxis,
int64_t targetTensorAxis) {
- SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ SmallVector<MeshAxesAttr> targetShardingSplitAxes =
llvm::to_vector(sourceSharding.getSplitAxes());
while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
targetTensorAxis) {
- targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
}
auto sourceSplitAxes =
@@ -488,13 +488,13 @@ targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
auto meshAxis = sourceSplitAxes.back();
sourceSplitAxes.pop_back();
targetShardingSplitAxes[sourceTensorAxis] =
- DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+ MeshAxesAttr::get(ctx, sourceSplitAxes);
auto targetSplitAxes =
llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
targetSplitAxes.push_back(meshAxis);
targetShardingSplitAxes[targetTensorAxis] =
- DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ MeshAxesAttr::get(ctx, targetSplitAxes);
return MeshShardingAttr::get(
ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
More information about the Mlir-commits
mailing list