[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)
llvmlistbot at llvm.org
Fri Mar 8 08:21:43 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Boian Petkantchin (sogartar)
Changes:
If the sharding annotations of an op conflict, insert a resharding to resolve them.
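For illustration, here is a minimal sketch of the pattern this produces, assuming a hypothetical mesh `@mesh_2` and made-up shapes (not taken from the patch): when a producer-side annotation and a user-side annotation disagree, propagation materializes a source/target `mesh.shard` pair, which spmdization later lowers into an actual resharding.

```mlir
mesh.mesh @mesh_2(shape = 2)

func.func @conflict_sketch(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // Producer-side annotation: split axis 0 across @mesh_2.
  %0 = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x2xf32>
  // Inserted to resolve the conflict: the user expects full replication, so
  // a target annotation with `annotate_for_users` records that sharding; the
  // %0 -> %1 pair becomes a resharding during spmdization.
  %1 = mesh.shard %0 to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32>
  return %1 : tensor<2x2xf32>
}
```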
Make the Spmdization pass more forgiving, so that it accepts chains of more than two `mesh.shard` ops.
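As a hedged sketch of what such a chain looks like (function name and shapes hypothetical, `@mesh_2` as above), the pass previously asserted on anything other than a plain source/target pair; a chain like the following is now tolerated:

```mlir
func.func @chain_sketch(%v: tensor<2xf32>) -> tensor<2xf32> {
  // A producer annotation and a matching user annotation...
  %0 = mesh.shard %v to <@mesh_2, [[0]]> : tensor<2xf32>
  %1 = mesh.shard %0 to <@mesh_2, [[0]]> annotate_for_users : tensor<2xf32>
  // ...followed by a third annotation requesting replication: three chained
  // mesh.shard ops on one value, as exercised by the sharding_triplet test.
  %2 = mesh.shard %1 to <@mesh_2, [[]]> : tensor<2xf32>
  return %2 : tensor<2xf32>
}
```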
Implement `getReductionLoopIteratorKinds` in ShardingInterface for linalg ops.
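For context, a small illustrative example (not from the patch): `linalg.matmul` has iterator types `[parallel, parallel, reduction]`, and its reduction (k) loop combines partial products by addition, so the new hook would report a single `sum` reduction kind for an op like:

```mlir
func.func @reduction_kind_sketch(%a: tensor<4x8xf32>, %b: tensor<8x4xf32>,
                                 %c: tensor<4x4xf32>) -> tensor<4x4xf32> {
  // One reduction iterator (k); getReductionLoopIteratorKinds returns {sum}.
  %res = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x4xf32>)
                       outs(%c : tensor<4x4xf32>) -> tensor<4x4xf32>
  return %res : tensor<4x4xf32>
}
```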
---
Patch is 20.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84514.diff
9 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+25-3)
- (modified) mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp (+15)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+85-1)
- (modified) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+5-19)
- (modified) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+2)
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (-3)
- (added) mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir (+34)
- (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+35-5)
- (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+15)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 4569b77441c3f3..7a24c201a39a77 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,15 +51,26 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
// Is the same tensor replicated on all processes.
inline bool isFullReplication(MeshShardingAttr attr) {
- return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+ return attr.getPartialAxes().empty() &&
+ llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+ return axes.asArrayRef().empty();
+ });
}
-inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTableCollection) {
+inline mesh::MeshOp
+getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTableCollection) {
return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
op, meshSymbol);
}
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTableCollection) {
+ mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
+ assert(meshOp);
+ return meshOp;
+}
+
// Get the corresponding mesh op using the standard attribute nomenclature.
template <typename Op>
mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
@@ -128,6 +139,17 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
// `sharding` in that case must be null.
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
+// Insert shard op if there is not one that already has the same sharding.
+// May insert resharding if required.
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder);
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpResult result, OpBuilder &builder);
+void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668b..632ee6e7b55852 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -36,6 +36,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>
+#include <numeric>
#include <optional>
#include <utility>
@@ -279,6 +280,20 @@ struct StructuredOpShardingInterface
return res;
}
+ SmallVector<ReductionKind>
+ getReductionLoopIteratorKinds(Operation *op) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+ SmallVector<utils::IteratorType> iteratorTypes =
+ linalgOp.getIteratorTypesArray();
+ unsigned reductionItersCount = std::accumulate(
+ iteratorTypes.begin(), iteratorTypes.end(), 0,
+ [](unsigned count, utils::IteratorType iter) {
+ return count + (iter == utils::IteratorType::reduction);
+ });
+ mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+ return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
+ }
+
LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
ArrayRef<MeshShardingAttr> resultShardings,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 03f11ad1f94965..343ec3ce7e317a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include <algorithm>
#include <functional>
#include <iterator>
@@ -99,7 +101,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
FlatSymbolRefAttr meshSymbol,
SymbolTableCollection &symbolTable) {
- mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
+ mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
if (!mesh) {
return op->emitError() << "Undefined required mesh symbol \""
<< meshSymbol.getValue() << "\".";
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
return type;
}
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Value operandValue = operand.get();
+ Operation *operandOp = operand.getOwner();
+ builder.setInsertionPointAfterValue(operandValue);
+ ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
+ if (shardOp && shardOp.getShard() == sharding &&
+ !shardOp.getAnnotateForUsers()) {
+ // No need to do anything; the correct sharding is already set.
+ return;
+ }
+
+ auto newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ /*annotate_for_users*/ false);
+ IRRewriter rewriter(builder);
+ rewriter.replaceUsesWithIf(
+ operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+ return use.getOwner() == operandOp && use.get() == operandValue;
+ });
+
+ if (!shardOp || shardOp.getAnnotateForUsers()) {
+ return;
+ }
+
+ auto newShardOp2 = builder.create<ShardOp>(
+ operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+ rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+}
+
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+ OpResult result,
+ OpBuilder &builder) {
+ for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+ maybeInsertTargetShardingAnnotation(sharding, use, builder);
+ }
+}
+
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+ OpOperand &operand,
+ OpBuilder &builder) {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Value operandValue = operand.get();
+ Operation *operandOp = operand.getOwner();
+ Operation *operandSrcOp = operandValue.getDefiningOp();
+ bool isBlockArg = !operandSrcOp;
+ ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
+
+ if (shardOp && shardOp.getShard() == sharding &&
+ shardOp.getAnnotateForUsers()) {
+ // No need to do anything; the correct sharding is already set.
+ return;
+ }
+
+ builder.setInsertionPoint(operandOp);
+ auto newShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ /*annotate_for_users*/ true);
+ IRRewriter rewriter(builder);
+ rewriter.replaceUsesWithIf(
+ operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+ return use.getOwner() == operandOp && use.get() == operandValue;
+ });
+
+ if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
+ // No need for resharding.
+ return;
+ }
+
+ builder.setInsertionPoint(newShardOp);
+ auto newPreceedingShardOp =
+ builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+ /*annotate_for_users*/ false);
+ rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
+ [&newShardOp](OpOperand &use) {
+ return use.getOwner() ==
+ newShardOp.getOperation();
+ });
+}
+
//===----------------------------------------------------------------------===//
// mesh.mesh op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 9acee5aa8d8604..4ba61b46b6e08e 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -399,11 +399,6 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
AffineMap map,
ArrayRef<utils::IteratorType> loopTypes,
ArrayRef<ReductionKind> reductionLoopKinds) {
- FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
- getMeshShardingAttr(result);
- if (succeeded(maybeSharding) && !maybeSharding->first)
- return success();
-
auto resultType = result.getType().cast<RankedTensorType>();
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
SmallVector<MeshAxis> partialAxes;
@@ -440,11 +435,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
removeTrailingEmptySubArray(splitAxes);
MeshShardingAttr shardAttr = MeshShardingAttr::get(
b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointAfterValue(result);
- auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
- shardAttr, /*annotate_for_users*/ false);
- result.replaceAllUsesExcept(shardOp, shardOp);
+ maybeInsertTargetShardingAnnotation(shardAttr, result, b);
+
return success();
}
@@ -453,11 +445,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
const ShardingOption &shardingOption,
AffineMap map) {
- auto maybeShardingAttr = getMeshShardingAttr(opOperand);
- if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
- return success();
- Value operand = opOperand.get();
- auto operandType = operand.getType().cast<RankedTensorType>();
+ Value operandValue = opOperand.get();
+ auto operandType = operandValue.getType().cast<RankedTensorType>();
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@@ -486,10 +475,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
MeshShardingAttr shardAttr =
MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(opOperand.getOwner());
- auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
- shardAttr, true);
- opOperand.set(shardOp);
+ maybeInsertSourceShardingAnnotation(shardAttr, opOperand, b);
return success();
}
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 29320f1e339f86..b3ffe4a219254d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
@@ -199,6 +200,7 @@ struct ShardingPropagation
LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
<< funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
// 2. propagate in original order
for (Operation &op : llvm::make_early_inc_range(block))
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e4868435135ed1..cc60a3d4826658 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -499,8 +499,6 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ShardOp target,
TypedValue<ShapedType> sourceShardValue) {
- assert(!source.getAnnotateForUsers());
- assert(target.getAnnotateForUsers());
assert(source.getResult() == target.getOperand());
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(
@@ -635,7 +633,6 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
} else {
// Insert resharding.
- assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
new file mode 100644
index 00000000000000..59fd548dc2ef2c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt \
+// RUN: --verify-each \
+// RUN: --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN: %s | FileCheck %s
+
+mesh.mesh @mesh_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_parallel_axis
+func.func @matmul_shard_parallel_axis(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+ %arg0 : tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+ %arg1 : tensor<3x2xf32>,
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+ // CHECK: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[}}[0]]> : tensor<2x3xf32>
+ // CHECK: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x2xf32>
+ %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+ // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+ // CHECK: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to <@mesh_2, {{\[}}[0]]> : tensor<2x2xf32>
+ // CHECK: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to <@mesh_2, {{\[}}[]]> annotate_for_users : tensor<2x2xf32>
+ %res_sharded = mesh.shard %res to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32>
+
+ // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 270787ab518831..f419b709df9f56 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
+mesh.mesh @mesh_2(shape = 2)
mesh.mesh @mesh_1d(shape = ?)
mesh.mesh @mesh_2d(shape = 2x4)
mesh.mesh @mesh_3d(shape = ?x?x?)
@@ -73,12 +74,11 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
// CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
// CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
%1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
- // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
- // CHECK-NEXT: %[[V8:.*]] = tosa.negate %[[V7]]
- // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V7:.*]] = tosa.negate %[[V4]]
+ // CHECK-NEXT: %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
%2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
%3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
- // CHECK-NEXT: return %[[V6]], %[[V9]]
+ // CHECK-NEXT: return %[[V6]], %[[V8]]
return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
}
@@ -135,6 +135,36 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
return %2 : tensor<2x16x32xf32>
}
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+ // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+ %arg0: tensor<2x3xf32>,
+ // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+ %arg1: tensor<3x2xf32>,
+ // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+ %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+ // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
+ // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x3xf32>
+ // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+ // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+ %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+ // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+ // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+ %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+ outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+ // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x2xf32>
+ // CHECK: %[[MATMUL_SHARDED2:.*]] = mesh.shard %[[MATMUL_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+ // CHECK: %[[MATMUL_SHARDED3:.*]] = mesh.shard %[[MATMUL_SHARDED2]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
+ %res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32>
+
+ // CHECK: return %[[MATMUL_SHARDED3]] : tensor<2x2xf32>
+ return %res_sharded : tensor<2x2xf32>
+}
+
// https://arxiv.org/abs/2211.05102 Figure 2(a)
// CHECK-LABEL: func.func @mlp_1d_weight_stationary
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2df247aba35155..d7a1e2fd9d2790 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -16,6 +16,21 @@ func.func @full_replication(
return %1 : tensor<2xi8>
}
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+ %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+ // CHECK: %[[AL...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/84514
More information about the Mlir-commits mailing list