[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)
Boian Petkantchin
llvmlistbot at llvm.org
Mon May 20 15:14:35 PDT 2024
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/84514
>From 2edfda9f77ce690e77040efd5a38fe5efd910f20 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Mar 2024 07:03:02 -0800
Subject: [PATCH 1/2] [mlir][mesh] Insert resharding during sharding
propagation
If there are conflicts between the sharding annotations of some op, insert
resharding.
Make the Spmdization pass more forgiving to allow for more than 2 chained
`mesh.shard` ops.
Implement `getReductionLoopIteratorKinds` in ShardingInterface for linalg ops.
---
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 28 +++++-
.../Transforms/MeshShardingInterfaceImpl.cpp | 15 ++++
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 86 ++++++++++++++++++-
.../Mesh/Interfaces/ShardingInterface.cpp | 24 ++----
.../Mesh/Transforms/ShardingPropagation.cpp | 2 +
.../Dialect/Mesh/Transforms/Spmdization.cpp | 3 -
.../Linalg/mesh-sharding-propagation.mlir | 34 ++++++++
.../Dialect/Mesh/sharding-propagation.mlir | 40 +++++++--
mlir/test/Dialect/Mesh/spmdization.mlir | 15 ++++
9 files changed, 216 insertions(+), 31 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 4569b77441c3f..7a24c201a39a7 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,15 +51,26 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshShardingAttr attr) {
- return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+ return attr.getPartialAxes().empty() &&
+ llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+ return axes.asArrayRef().empty();
+ });
}
-inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTableCollection) {
+inline mesh::MeshOp
+getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTableCollection) {
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
op, meshSymbol);
}
+// Get the mesh op referenced by `meshSymbol`, looked up starting from `op`.
+// The symbol is required to resolve; use getMeshOrNull when it may not.
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
+  assert(meshOp && "Mesh symbol does not resolve to a mesh.mesh op.");
+  return meshOp;
+}
+
// Get the corresponding mesh op using the standard attribute nomenclature.
template <typename Op>
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
@@ -128,6 +139,17 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
// `sharding` in that case must be null.
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+// Helpers that insert `mesh.shard` annotations when sharding annotations are
+// added to ops. Each one inserts a shard op only if there is not one that
+// already has the same sharding. It may insert an additional shard op so that a
+// conflicting existing annotation is expressed as a resharding.
+//
+// Target variant: annotate the value used by `operand` with `sharding` as the
+// sharding it is produced with (a `mesh.shard` without `annotate_for_users`,
+// placed right after the value's definition).
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder);
+// Apply the target variant above to the uses of `result`.
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpResult result, OpBuilder &builder);
+// Source variant: annotate `operand` with `sharding` as the sharding its owner
+// expects (a `mesh.shard` with `annotate_for_users`, placed right before the
+// owner).
+void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668..632ee6e7b5585 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -36,6 +36,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
+#include <numeric>
#include <optional>
#include <utility>
@@ -279,6 +280,20 @@ struct StructuredOpShardingInterface
return res;
}
+  // Return one reduction kind per reduction loop of the linalg op. Every
+  // reduction loop is given the same kind, as computed by
+  // getReductionKindOfLinalgOp.
+  SmallVector<ReductionKind>
+  getReductionLoopIteratorKinds(Operation *op) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    // Seed with `0u`: std::accumulate accumulates in the type of its initial
+    // value, and an `int` literal would mix `int` with the `unsigned` lambda
+    // parameter and result.
+    unsigned reductionItersCount = std::accumulate(
+        iteratorTypes.begin(), iteratorTypes.end(), 0u,
+        [](unsigned count, utils::IteratorType iter) {
+          return count + (iter == utils::IteratorType::reduction ? 1u : 0u);
+        });
+    mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+    return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
+  }
+
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 03f11ad1f9496..343ec3ce7e317 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include <algorithm>
#include <functional>
#include <iterator>
@@ -99,7 +101,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
- mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
+ mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
if (!mesh) {
return op->emitError() << "Undefined required mesh symbol \""
<< meshSymbol.getValue() << "\".";
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
return type;
}
+// Annotate the value used by `operand` with `sharding` as the sharding it is
+// produced with. This inserts `mesh.shard %value to sharding` (without
+// `annotate_for_users`) right after the value's definition and redirects the
+// owner's uses of the value to it.
+// Nothing is inserted if the owner already is such a `mesh.shard` op with the
+// same sharding. If the owner is a `mesh.shard` op without
+// `annotate_for_users` but with a different sharding, an additional
+// `mesh.shard ... annotate_for_users` with `sharding` is inserted in between.
+// The chain then expresses a resharding.
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Value operandValue = operand.get();
+ // Note: this is the user of `operandValue` (the owner of `operand`), not its
+ // producer.
+ Operation *operandOp = operand.getOwner();
+ builder.setInsertionPointAfterValue(operandValue);
+ ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
+ if (shardOp && shardOp.getShard() == sharding &&
+ !shardOp.getAnnotateForUsers()) {
+ // Nothing to do: the correct sharding is already set.
+ return;
+ }
+
+ auto newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ /*annotate_for_users*/ false);
+ IRRewriter rewriter(builder);
+ // Redirects every use of `operandValue` owned by `operandOp`, not only
+ // `operand`. The new shard op's own use of `operandValue` is not matched
+ // because its owner is the new op.
+ rewriter.replaceUsesWithIf(
+ operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+ return use.getOwner() == operandOp && use.get() == operandValue;
+ });
+
+ // The single annotation is enough unless the owner is a `mesh.shard` op
+ // without `annotate_for_users`. Given the check above, such an owner has a
+ // different sharding.
+ if (!shardOp || shardOp.getAnnotateForUsers()) {
+ return;
+ }
+
+ // Resharding chain: value -> shard(sharding) ->
+ // shard(sharding, annotate_for_users) -> owner shard op (other sharding).
+ // newShardOp2 takes over the owner's use of newShardOp.
+ auto newShardOp2 = builder.create<ShardOp>(
+ operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+ rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+}
+
+// Annotate every use of `result` with `sharding` as its producer-side
+// sharding, by applying the per-use overload above to each use.
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpResult result,
+                                                     OpBuilder &builder) {
+  // Snapshot the uses up front. The per-use overload rewires every use owned
+  // by the same op (e.g. both operands of `tosa.add %r, %r`). Walking the live
+  // use list with an early-increment iterator could therefore step onto a use
+  // that has already moved to the new `mesh.shard` op's use list. Iteration
+  // would then continue on that list and skip the remaining uses of `result`.
+  SmallVector<OpOperand *> originalUses;
+  for (OpOperand &use : result.getUses())
+    originalUses.push_back(&use);
+
+  for (OpOperand *use : originalUses) {
+    // Already handled together with an earlier use by the same owner.
+    if (use->get() != result)
+      continue;
+    maybeInsertTargetShardingAnnotation(sharding, *use, builder);
+  }
+}
+
+// Annotate `operand` with `sharding` as the sharding expected by its owner.
+// This inserts `mesh.shard %v to sharding annotate_for_users` right before the
+// owner. Nothing is inserted if the operand is already produced by such a
+// shard op with the same sharding.
+// If the operand is produced by a shard op that annotates for users with a
+// different sharding, a `mesh.shard` without `annotate_for_users` is inserted
+// in between. The chain then expresses a resharding to `sharding`.
+//
+// Only `operand` itself is rewired. Other operands of the same owner that use
+// the same value are left alone, since they may need a different sharding.
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  // Null for block arguments and for values not produced by a shard op.
+  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandValue.getDefiningOp());
+
+  if (shardOp && shardOp.getShard() == sharding &&
+      shardOp.getAnnotateForUsers()) {
+    // Nothing to do: the correct sharding is already set.
+    return;
+  }
+
+  builder.setInsertionPoint(operandOp);
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ true);
+  IRRewriter rewriter(builder);
+  // Rewire only this operand. Matching on the owner instead would also redirect
+  // sibling operands that use the same value (e.g. `matmul %a, %a`). A later
+  // call for such a sibling would then overwrite this operand's annotation.
+  rewriter.replaceUsesWithIf(operandValue, newShardOp,
+                             [&operand](OpOperand &use) {
+                               return &use == &operand;
+                             });
+
+  if (!shardOp || !shardOp.getAnnotateForUsers()) {
+    // No need for resharding.
+    return;
+  }
+
+  // The operand was produced by a shard op annotating for users with another
+  // sharding. Insert a producer-side `mesh.shard` with `sharding` in front of
+  // the new annotation so that the chain expresses a resharding.
+  builder.setInsertionPoint(newShardOp);
+  auto newPrecedingShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ false);
+  rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPrecedingShardOp,
+                             [&newShardOp](OpOperand &use) {
+                               return use.getOwner() ==
+                                      newShardOp.getOperation();
+                             });
+}
+
//===----------------------------------------------------------------------===//
// mesh.mesh op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 9acee5aa8d860..4ba61b46b6e08 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -399,11 +399,6 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
AffineMap map,
ArrayRef<utils::IteratorType> loopTypes,
ArrayRef<ReductionKind> reductionLoopKinds) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
- getMeshShardingAttr(result);
- if (succeeded(maybeSharding) && !maybeSharding->first)
- return success();
-
auto resultType = result.getType().cast<RankedTensorType>();
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
SmallVector<MeshAxis> partialAxes;
@@ -440,11 +435,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
removeTrailingEmptySubArray(splitAxes);
MeshShardingAttr shardAttr = MeshShardingAttr::get(
b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointAfterValue(result);
- auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
- shardAttr, /*annotate_for_users*/ false);
- result.replaceAllUsesExcept(shardOp, shardOp);
+ maybeInsertTargetShardingAnnotation(shardAttr, result, b);
+
return success();
}
@@ -453,11 +445,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
- auto maybeShardingAttr = getMeshShardingAttr(opOperand);
- if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
- return success();
- Value operand = opOperand.get();
- auto operandType = operand.getType().cast<RankedTensorType>();
+ Value operandValue = opOperand.get();
+ auto operandType = operandValue.getType().cast<RankedTensorType>();
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@@ -486,10 +475,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
MeshShardingAttr shardAttr =
MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(opOperand.getOwner());
- auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
- shardAttr, true);
- opOperand.set(shardOp);
+ maybeInsertSourceShardingAnnotation(shardAttr, opOperand, b);
return success();
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 29320f1e339f8..b3ffe4a219254 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
@@ -199,6 +200,7 @@ struct ShardingPropagation
LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
<< funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
// 2. propagate in original order
for (Operation &op : llvm::make_early_inc_range(block))
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e4868435135ed..cc60a3d482665 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -499,8 +499,6 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
- assert(!source.getAnnotateForUsers());
- assert(target.getAnnotateForUsers());
assert(source.getResult() == target.getOperand());
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(
@@ -635,7 +633,6 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
} else {
// Insert resharding.
- assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
new file mode 100644
index 0000000000000..59fd548dc2ef2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt \
+// RUN:   --verify-each \
+// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN:   %s | FileCheck %s
+
+mesh.mesh @mesh_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_parallel_axis
+func.func @matmul_shard_parallel_axis(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+  %arg0 : tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+  %arg1 : tensor<3x2xf32>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+  // CHECK: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[}}[0]]> : tensor<2x3xf32>
+  // CHECK: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x2xf32>
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+  // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to <@mesh_2, {{\[}}[0]]> : tensor<2x2xf32>
+  // CHECK: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to <@mesh_2, {{\[}}[]]> annotate_for_users : tensor<2x2xf32>
+  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32>
+
+  // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 270787ab51883..f419b709df9f5 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
+mesh.mesh @mesh_2(shape = 2)
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)
mesh.mesh @mesh_3d(shape = ?x?x?)
@@ -73,12 +74,11 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
// CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
// CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
%1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V8:.*]] = tosa.negate %[[V7]]
- // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]]
+ // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
%2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
%3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V6]], %[[V9]]
+ // CHECK-NEXT: return %[[V6]], %[[V8]]
return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
}
@@ -135,6 +135,36 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
return %2 : tensor<2x16x32xf32>
}
+// The annotation on %arg0 (split along dim 0) leads to a `[[0]]` sharding for
+// the matmul result. That conflicts with the explicit replicated (`[[]]`)
+// annotation on the result. Check that propagation keeps the explicit
+// annotations and inserts `mesh.shard` ops so that the conflict is expressed as
+// a resharding.
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+ // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+ %arg0: tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+ %arg1: tensor<3x2xf32>,
+ // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+ // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
+ // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+ %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+ // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+ // The result is first annotated with the propagated `[[0]]` sharding and
+ // then resharded to the explicit `[[]]` sharding.
+ // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x2xf32>
+ // CHECK: %[[MATMUL_SHARDED2:.*]] = mesh.shard %[[MATMUL_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+ // CHECK: %[[MATMUL_SHARDED3:.*]] = mesh.shard %[[MATMUL_SHARDED2]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
+ %res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32>
+
+ // CHECK: return %[[MATMUL_SHARDED3]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
+
// https://arxiv.org/abs/2211.05102 Figure 2(a)
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2df247aba3515..d7a1e2fd9d279 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -16,6 +16,21 @@ func.func @full_replication(
return %1 : tensor<2xi8>
}
+// A chain of three `mesh.shard` ops (split along dim 0, the same sharding
+// annotated for users, then replicated) must spmdize to a single all-gather.
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
+ %sharding_annotated = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xf32>
+ %sharding_annotated_0 = mesh.shard %sharding_annotated to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xf32>
+ %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d, [[]]> : tensor<2xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
+ return %sharding_annotated_1 : tensor<2xf32>
+}
+
+
// CHECK-LABEL: func @move_split_axis
func.func @move_split_axis(
// CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>
>From f32e1194bcd930e7549f892ef45cf99c551feee2 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Wed, 15 May 2024 19:37:45 -0500
Subject: [PATCH 2/2] Address Chengji's PR comments
---
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 2 +
.../Mesh/Interfaces/ShardingInterface.h | 9 +
.../Mesh/Interfaces/ShardingInterface.td | 25 +-
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 8 +
.../Mesh/Interfaces/ShardingInterface.cpp | 101 ++++++--
.../Mesh/Transforms/ShardingPropagation.cpp | 229 ++++++++++++++++--
.../Dialect/Mesh/sharding-propagation.mlir | 10 +-
7 files changed, 333 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d9b5892e1a51..3a85bf2d552f3 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -151,7 +151,9 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let extraClassDeclaration = [{
bool operator==(::mlir::Attribute rhs) const;
+ bool operator!=(::mlir::Attribute rhs) const;
bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+ bool operator!=(::mlir::mesh::MeshShardingAttr rhs) const;
}];
let genVerifyDecl = 1;
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index c47a7ddd3f9cc..216d7e10296df 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -37,6 +37,11 @@ struct ShardingOption {
ShardingOption() = default;
ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr mesh)
: shardingArray(std::move(shardingArray)), mesh(mesh) {}
+  // Build a ShardingOption that is flagged as `empty`; all other members are
+  // value-initialized.
+  static ShardingOption makeEmpty() {
+    ShardingOption option{};
+    option.empty = true;
+    return option;
+  }
};
// This method retrieves the 'MeshShardingAttr' attribute from a given operation
@@ -56,6 +61,10 @@ defaultGetShardingOption(Operation *op,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings);
+// Default implementation of ShardingInterface::getShardingAnnotations.
+// Returns the shardings that the operands and results of `op` need for `op` to
+// be sharded according to `shardingOption`: operand shardings first, then
+// result shardings. Unlike defaultAddShardingAnnotations, it returns them
+// instead of inserting `mesh.shard` ops.
+FailureOr<SmallVector<MeshShardingAttr>>
+defaultGetShardingAnnotations(Operation *op,
+ const ShardingOption &shardingOption);
+
LogicalResult
defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
const ShardingOption &shardingOption);
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
index 1f75135f42882..47a74f619f56c 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -75,8 +75,11 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
InterfaceMethod<
/*desc=*/[{
Given that certain operands or results of the operation may have
- sharding annotations, this method leverages this information to deduce
- how the operation should be sharded.
+ sharding annotations, this method leverages this information to
+ deduce how the operation should be sharded.
+ The passed sharding may be incomplete, this gives freedom for the
+ op to select the most appropriate shardings for all the operands
+ and results and the op itself.
}],
/*retTy=*/"FailureOr<ShardingOption>",
/*methodName=*/"getShardingOption",
@@ -90,6 +93,24 @@ def ShardingInterface : OpInterface<"ShardingInterface"> {
$_op.getOperation(), operandShardings, resultShardings);
}]
>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Based on a given ShardingOption, get the sharding annotations for
+        the operands and results of the operation. These are the shardings
+        the operands and results need to have in order to shard the op
+        according to `shardingOption`. The default implementation returns
+        the operand shardings first, followed by the result shardings.
+      }],
+      /*retTy=*/"FailureOr<SmallVector<MeshShardingAttr>>",
+      /*methodName=*/"getShardingAnnotations",
+      /*args=*/(ins
+        "const ShardingOption &":$shardingOption
+      ),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return detail::defaultGetShardingAnnotations(
+          $_op.getOperation(), shardingOption);
+      }]
+    >,
InterfaceMethod<
/*desc=*/[{
Based on a given ShardingOption, this method adds `mesh.shard`
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 343ec3ce7e317..42f7a4268ea83 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -369,6 +369,10 @@ bool MeshShardingAttr::operator==(Attribute rhs) const {
return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
}
+bool MeshShardingAttr::operator!=(Attribute rhs) const {
+  // Defined as the negation of operator==(Attribute) so the two always agree.
+  return !operator==(rhs);
+}
+
bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
if (getMesh() != rhs.getMesh() || getPartialAxes() != rhs.getPartialAxes()) {
return false;
@@ -394,6 +398,10 @@ bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
std::mem_fn(&MeshAxesAttr::empty));
}
+bool MeshShardingAttr::operator!=(MeshShardingAttr rhs) const {
+  // Defined as the negation of operator==(MeshShardingAttr) so the two always
+  // agree.
+  return !operator==(rhs);
+}
+
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 4ba61b46b6e08..41f23e6455e05 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
@@ -388,17 +389,11 @@ FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(
return shardingOption;
}
-//===----------------------------------------------------------------------===//
-// detail::defaultAddShardingAnnotations
-//===----------------------------------------------------------------------===//
-
-// To add a `mesh.shard` op for the given result, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
-static LogicalResult addShardOp(OpBuilder &b, OpResult result,
- const ShardingOption &shardingOption,
- AffineMap map,
- ArrayRef<utils::IteratorType> loopTypes,
- ArrayRef<ReductionKind> reductionLoopKinds) {
+// Get the sharding attribute for the given result and sharding option.
+MeshShardingAttr
+getShardingAttribute(OpResult result, const ShardingOption &shardingOption,
+ AffineMap map, ArrayRef<utils::IteratorType> loopTypes,
+ ArrayRef<ReductionKind> reductionLoopKinds) {
auto resultType = result.getType().cast<RankedTensorType>();
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
SmallVector<MeshAxis> partialAxes;
@@ -433,18 +428,13 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
}
removeTrailingEmptySubArray(splitAxes);
- MeshShardingAttr shardAttr = MeshShardingAttr::get(
- b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
- maybeInsertTargetShardingAnnotation(shardAttr, result, b);
-
- return success();
+ return MeshShardingAttr::get(result.getContext(), shardingOption.mesh,
+ splitAxes, partialAxes, partialType);
}
-// To add a `mesh.shard` op for the given operand, based on the details provided
-// in `shardingOption`, `map`, and `loopTypes`.
-static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
- const ShardingOption &shardingOption,
- AffineMap map) {
+static FailureOr<MeshShardingAttr>
+getShardingAttribute(OpOperand &opOperand, const ShardingOption &shardingOption,
+ AffineMap map) {
Value operandValue = opOperand.get();
auto operandType = operandValue.getType().cast<RankedTensorType>();
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
@@ -472,16 +462,79 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
}
removeTrailingEmptySubArray(splitAxes);
- MeshShardingAttr shardAttr =
- MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
+ return MeshShardingAttr::get(opOperand.get().getContext(),
+ shardingOption.mesh, splitAxes);
+}
+
+// Default implementation of ShardingInterface::getShardingAnnotations.
+// Derives, from `shardingOption` and the op's indexing maps, the sharding that
+// each operand and each result needs for `op` to be sharded accordingly. The
+// returned vector holds the operand shardings first (in operand order),
+// followed by the result shardings (in result order). Fails if the sharding of
+// any operand cannot be derived.
+FailureOr<SmallVector<MeshShardingAttr>>
+mesh::detail::defaultGetShardingAnnotations(
+    Operation *op, const ShardingOption &shardingOption) {
+  auto shardingOp = llvm::cast<ShardingInterface>(op);
+  SmallVector<utils::IteratorType> loopTypes =
+      shardingOp.getLoopIteratorTypes();
+  SmallVector<ReductionKind> reductionKinds =
+      shardingOp.getReductionLoopIteratorKinds();
+  SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+  // The indexing maps of the results come after those of the operands.
+  unsigned resultMapsBegin = op->getNumOperands();
+
+  SmallVector<MeshShardingAttr> annotations;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    AffineMap operandMap = maps[opOperand.getOperandNumber()];
+    FailureOr<MeshShardingAttr> operandSharding =
+        getShardingAttribute(opOperand, shardingOption, operandMap);
+    if (failed(operandSharding))
+      return failure();
+    annotations.push_back(*operandSharding);
+  }
+
+  for (OpResult result : op->getResults()) {
+    AffineMap resultMap = maps[resultMapsBegin + result.getResultNumber()];
+    annotations.push_back(getShardingAttribute(
+        result, shardingOption, resultMap, loopTypes, reductionKinds));
+  }
+
+  return annotations;
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultAddShardingAnnotations
+//===----------------------------------------------------------------------===//
+
+// Add `mesh.shard` annotations for `result`. The sharding is derived from
+// `shardingOption`, the result's indexing `map`, `loopTypes` and
+// `reductionLoopKinds`, and every use of the result is annotated with it.
+// Always succeeds.
+static LogicalResult addShardOp(OpBuilder &b, OpResult result,
+                                const ShardingOption &shardingOption,
+                                AffineMap map,
+                                ArrayRef<utils::IteratorType> loopTypes,
+                                ArrayRef<ReductionKind> reductionLoopKinds) {
+  maybeInsertTargetShardingAnnotation(
+      getShardingAttribute(result, shardingOption, map, loopTypes,
+                           reductionLoopKinds),
+      result, b);
+  return success();
+}
+
+// To add a `mesh.shard` op for the given operand, based on the details provided
+// in `shardingOption`, `map`, and `loopTypes`.
+static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
+ const ShardingOption &shardingOption,
+ AffineMap map) {
+
+ FailureOr<MeshShardingAttr> shardAttr =
+ getShardingAttribute(opOperand, shardingOption, map);
+ if (failed(shardAttr)) {
+ return failure();
+ }
OpBuilder::InsertionGuard guard(b);
- maybeInsertSourceShardingAnnotation(shardAttr, opOperand, b);
+ maybeInsertSourceShardingAnnotation(*shardAttr, opOperand, b);
return success();
}
LogicalResult mesh::detail::defaultAddShardingAnnotations(
Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
+ assert(!shardingOption.empty && shardingOption.mesh);
+
ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
SmallVector<utils::IteratorType> loopTypes =
shardingOp.getLoopIteratorTypes();
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index b3ffe4a219254..65d12a16085d3 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -15,7 +15,13 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
#include <vector>
namespace mlir {
@@ -31,6 +37,70 @@ namespace mesh {
using namespace mlir;
using namespace mlir::mesh;
+/// How much resharding a candidate sharding option would force on the
+/// `mesh.shard` annotations already attached to an op, from best to worst.
+/// The numeric order is significant: candidates are ranked by comparing these
+/// values with `<`, so a smaller value is preferred.
+enum class ReshardingRquirementKind {
+  /// Every existing annotation already matches the candidate.
+  NO_RESHARDING = 0,
+  /// Only annotations that do not explicitly target the op would change.
+  NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 1,
+  /// At least one annotation explicitly targeting the op would change.
+  RESHARDING_FOR_EXPLICIT_ANNOTATIONS = 2
+};
+
+// Stream helpers used only inside LLVM_DEBUG(...). `LLVM_DEBUG` itself is
+// always defined as a macro, so `#ifdef LLVM_DEBUG` would never exclude these.
+// Guard on NDEBUG instead: in NDEBUG builds LLVM_DEBUG discards its argument,
+// and the non-template helpers would otherwise be unused static functions.
+#ifndef NDEBUG
+
+// Forward declarations, so that the templates below find every overload through
+// ordinary lookup at their point of definition (e.g. for nested vectors).
+template <typename T>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const SmallVector<T> &vec);
+template <typename... Ts>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const std::tuple<Ts...> &t);
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     ReshardingRquirementKind v);
+
+// Prints `range` as "[e0, e1, ...]", without a trailing separator.
+template <typename Stream, typename Range>
+static Stream &printRange(Stream &stream, Range &&range) {
+  stream << "[";
+  llvm::interleave(
+      range, [&stream](const auto &v) { stream << v; },
+      [&stream] { stream << ", "; });
+  return stream << "]";
+}
+
+template <typename T>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const SmallVector<T> &vec) {
+  return printRange(stream, vec);
+}
+
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const ShardingOption &v) {
+  return stream << "{empty = " << v.empty << ", mesh = " << v.mesh
+                << ", shardingArray = " << v.shardingArray << "}";
+}
+
+// Prints the elements of `tuple` selected by `Is` as "{t0, t1, ...}".
+// The tuple is taken by const reference to avoid copying its elements.
+template <typename Stream, typename... Ts, size_t... Is>
+static Stream &printTuple(Stream &stream, const std::tuple<Ts...> &tuple,
+                          std::index_sequence<Is...>) {
+  static_assert(sizeof...(Is) == sizeof...(Ts),
+                "Indices must have same number of elements as tuple types!");
+  static_assert(sizeof...(Ts) > 0, "Cannot insert empty tuple into stream.");
+
+  stream << "{";
+  // Emit the separator before every element except the first.
+  ((stream << (Is == 0 ? "" : ", ") << std::get<Is>(tuple)), ...);
+  return stream << "}";
+}
+
+template <typename... Ts>
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     const std::tuple<Ts...> &t) {
+  return printTuple(stream, t, std::index_sequence_for<Ts...>{});
+}
+
+static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
+                                     ReshardingRquirementKind v) {
+  return stream << static_cast<int>(v);
+}
+
+#endif // NDEBUG
+
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
@@ -78,6 +148,138 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshShardingAttr> mustShardings,
return allShardingAttrs;
}
+// Sharding-option selection prefers, from highest to lowest:
+// 1. No resharding is required (all existing annotations are compatible).
+// 2. No resharding for operands/results that have an annotation specifically
+//    targeting this operation. Such annotations are
+//    * on operands: results of `mesh.shard` ops marked `annotate_for_users`;
+//    * on results: `mesh.shard` users not marked `annotate_for_users`.
+// 3. All other cases: resharding is required for operands/results whose
+//    annotation explicitly targets this operation.
+//
+// Classifies one candidate for `op` into the categories above.
+// `operandAndResultShardings` holds the shardings the candidate requires for
+// all operands of `op`, followed by those for all of its results.
+static ReshardingRquirementKind getReshardingRquirementKind(
+    Operation *op, ArrayRef<MeshShardingAttr> operandAndResultShardings) {
+  ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING;
+
+  // Folds one existing `mesh.shard` op into `res`. Returns true when the worst
+  // category has been reached, so the caller can stop scanning.
+  auto recordMismatch = [&res](ShardOp shardOp, MeshShardingAttr sharding,
+                               bool isExplicitAnnotationForThisOp) {
+    if (shardOp.getShardAttr() == sharding)
+      return false;
+    if (isExplicitAnnotationForThisOp) {
+      res = ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+      return true;
+    }
+    res = ReshardingRquirementKind::NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+    return false;
+  };
+
+  size_t operandsCount = op->getNumOperands();
+  ArrayRef<MeshShardingAttr> operandShardings =
+      operandAndResultShardings.take_front(operandsCount);
+  ArrayRef<MeshShardingAttr> resultShardings =
+      operandAndResultShardings.drop_front(operandsCount);
+
+  // An operand's annotation targets this op when it is `annotate_for_users`.
+  for (auto [operand, sharding] :
+       llvm::zip_equal(op->getOperands(), operandShardings)) {
+    auto shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
+    if (shardOp &&
+        recordMismatch(shardOp, sharding, shardOp.getAnnotateForUsers()))
+      return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+  }
+
+  // A result's annotation targets this op when it is NOT `annotate_for_users`.
+  for (auto [result, sharding] :
+       llvm::zip_equal(op->getResults(), resultShardings)) {
+    for (Operation *user : result.getUsers()) {
+      auto shardOp = llvm::dyn_cast<ShardOp>(user);
+      if (shardOp &&
+          recordMismatch(shardOp, sharding, !shardOp.getAnnotateForUsers()))
+        return ReshardingRquirementKind::RESHARDING_FOR_EXPLICIT_ANNOTATIONS;
+    }
+  }
+
+  return res;
+}
+
+// Picks, among the sharding options obtainable from the candidate operand and
+// result shardings, the one that forces the least resharding of existing
+// `mesh.shard` annotations (as classified by `getReshardingRquirementKind`).
+// Candidates are visited with `possibleResultShardingAttrs` as the outer loop
+// and `possibleOperandShardingAttrs` as the inner loop. Among equally good
+// candidates, the first one visited wins.
+// Returns an empty sharding option when no candidate yields a non-empty one,
+// and failure if `getShardingAnnotations` fails for a candidate.
+static FailureOr<ShardingOption> selectShardingOption(
+    ShardingInterface shardingOp,
+    ArrayRef<SmallVector<MeshShardingAttr>> possibleOperandShardingAttrs,
+    ArrayRef<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs) {
+  using Candidate = std::tuple<ShardingOption, ReshardingRquirementKind>;
+  SmallVector<Candidate> shardingOptionsAndReshardingRequirements;
+
+  for (ArrayRef<MeshShardingAttr> resultShardings :
+       possibleResultShardingAttrs) {
+    for (ArrayRef<MeshShardingAttr> operandShardings :
+         possibleOperandShardingAttrs) {
+      FailureOr<ShardingOption> shardingOption =
+          shardingOp.getShardingOption(operandShardings, resultShardings);
+      if (failed(shardingOption) || shardingOption->empty)
+        continue;
+      // These shardings may differ from `operandShardings` and
+      // `resultShardings`; some annotations may be missing. Whatever
+      // `getShardingAnnotations` returns is exactly what the op needs.
+      FailureOr<SmallVector<MeshShardingAttr>> operandAndResultShardings =
+          shardingOp.getShardingAnnotations(*shardingOption);
+      if (failed(operandAndResultShardings))
+        return failure();
+
+      LLVM_DEBUG(DBGS() << "operandAndResultShardings = "
+                        << *operandAndResultShardings << "\n";);
+
+      ReshardingRquirementKind reshardingRquirement =
+          getReshardingRquirementKind(shardingOp, *operandAndResultShardings);
+      if (reshardingRquirement == ReshardingRquirementKind::NO_RESHARDING) {
+        // This is the best case. No need to go on.
+        return *shardingOption;
+      }
+
+      shardingOptionsAndReshardingRequirements.emplace_back(
+          std::move(*shardingOption), reshardingRquirement);
+    }
+  }
+
+  if (shardingOptionsAndReshardingRequirements.empty())
+    return ShardingOption::makeEmpty();
+
+  LLVM_DEBUG(DBGS() << "shardingOptionsAndReshardingRequirements = "
+                    << shardingOptionsAndReshardingRequirements << "\n";);
+
+  // std::min_element returns the FIRST of several equal minima, so ties are
+  // broken by the preference order of the candidate lists. std::partial_sort
+  // is not stable and leaves the winner among ties unspecified.
+  auto best = std::min_element(
+      shardingOptionsAndReshardingRequirements.begin(),
+      shardingOptionsAndReshardingRequirements.end(),
+      [](const Candidate &a, const Candidate &b) {
+        return std::get<ReshardingRquirementKind>(a) <
+               std::get<ReshardingRquirementKind>(b);
+      });
+  return std::get<ShardingOption>(*best);
+}
+
// For each operation that implements the ShardingInterface, infer the sharding
// option of the operation from its operands and/or results using the
// `getShardingOption` method. If the inferred sharding option is not empty, add
@@ -136,32 +338,21 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
SmallVector<SmallVector<MeshShardingAttr>> possibleResultShardingAttrs =
getOrderedPossibleShardingAttrs(resultMustShardings,
allowConflictsResultShardings);
- FailureOr<ShardingOption> finalShardingOption = failure();
- for (ArrayRef<MeshShardingAttr> resultShardings :
- possibleResultShardingAttrs) {
- if (succeeded(finalShardingOption))
- break;
- for (ArrayRef<MeshShardingAttr> operandShardings :
- possibleOperandShardingAttrs) {
- FailureOr<ShardingOption> shardingOption =
- shardingOp.getShardingOption(operandShardings, resultShardings);
- if (succeeded(shardingOption)) {
- finalShardingOption = shardingOption;
- break;
- }
- }
- }
+ FailureOr<ShardingOption> shardingOption = selectShardingOption(
+ shardingOp, possibleOperandShardingAttrs, possibleResultShardingAttrs);
- if (failed(finalShardingOption)) {
+ if (failed(shardingOption)) {
op->emitOpError() << "fail to get sharding option.";
return failure();
}
+
+ LLVM_DEBUG(DBGS() << "Selected sharding option: " << *shardingOption << "\n");
+
// sharding info is empty, return immediately
- if (finalShardingOption->empty)
+ if (shardingOption->empty)
return success();
- if (failed(
- shardingOp.addShardingAnnotations(builder, *finalShardingOption))) {
+ if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
op->emitOpError() << "fail to set sharding annotations.";
return failure();
}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index f419b709df9f5..11a80594adb79 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -146,9 +146,9 @@ func.func @resolve_conflicting_annotations(
// CHECK-SAME: ) -> tensor<2x2xf32> {
) -> tensor<2x2xf32> {
// CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
- // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x3xf32>
// CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
- // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+ // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[}}]> annotate_for_users : tensor<2x2xf32>
%arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
@@ -156,12 +156,10 @@ func.func @resolve_conflicting_annotations(
%res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
- // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x2xf32>
- // CHECK: %[[MATMUL_SHARDED2:.*]] = mesh.shard %[[MATMUL_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
- // CHECK: %[[MATMUL_SHARDED3:.*]] = mesh.shard %[[MATMUL_SHARDED2]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
+ // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
%res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32>
- // CHECK: return %[[MATMUL_SHARDED3]] : tensor<2x2xf32>
+ // CHECK: return %[[MATMUL_SHARDED1]] : tensor<2x2xf32>
return %res_sharded : tensor<2x2xf32>
}
More information about the Mlir-commits
mailing list