[Mlir-commits] [mlir] [mlir][mesh] Insert resharding during sharding propagation (PR #84514)

Boian Petkantchin llvmlistbot at llvm.org
Fri Mar 8 08:21:13 PST 2024


https://github.com/sogartar created https://github.com/llvm/llvm-project/pull/84514

If there are conflicts between the sharding annotations of some op, insert resharding.
Make the Spmdization pass more forgiving to allow for more than 2 chained `mesh.shard` ops.

Implement `getReductionLoopIteratorKinds` in ShardingInterface for linalg ops.

>From 2edfda9f77ce690e77040efd5a38fe5efd910f20 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Mar 2024 07:03:02 -0800
Subject: [PATCH] [mlir][mesh] Insert resharding during sharding propagation

If there are conflicts between the sharding annotations of some op, insert
resharding.
Make the Spmdization pass more forgiving to allow for more than 2 chained
`mesh.shard` ops.

Implement `getReductionLoopIteratorKinds` in ShardingInterface for linalg ops.
---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h   | 28 +++++-
 .../Transforms/MeshShardingInterfaceImpl.cpp  | 15 ++++
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp          | 86 ++++++++++++++++++-
 .../Mesh/Interfaces/ShardingInterface.cpp     | 24 ++----
 .../Mesh/Transforms/ShardingPropagation.cpp   |  2 +
 .../Dialect/Mesh/Transforms/Spmdization.cpp   |  3 -
 .../Linalg/mesh-sharding-propagation.mlir     | 34 ++++++++
 .../Dialect/Mesh/sharding-propagation.mlir    | 40 +++++++--
 mlir/test/Dialect/Mesh/spmdization.mlir       | 15 ++++
 9 files changed, 216 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 4569b77441c3f3..7a24c201a39a77 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -51,15 +51,26 @@ void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
 
 // Is the same tensor replicated on all processes.
 inline bool isFullReplication(MeshShardingAttr attr) {
-  return attr.getPartialAxes().empty() && attr.getSplitAxes().empty();
+  return attr.getPartialAxes().empty() &&
+         llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
+           return axes.asArrayRef().empty();
+         });
 }
 
-inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
-                            SymbolTableCollection &symbolTableCollection) {
+inline mesh::MeshOp
+getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol,
+              SymbolTableCollection &symbolTableCollection) {
   return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
       op, meshSymbol);
 }
 
+inline mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+                            SymbolTableCollection &symbolTableCollection) {
+  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
+  assert(meshOp);
+  return meshOp;
+}
+
 // Get the corresponding mesh op using the standard attribute nomenclature.
 template <typename Op>
 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
@@ -128,6 +139,17 @@ ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
 // `sharding` in that case must be null.
 Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
 
+// Insert a shard op unless one with the same sharding is already present.
+// May insert resharding if required.
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                         OpOperand &operand,
+                                         OpBuilder &builder);
+void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                         OpResult result, OpBuilder &builder);
+void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+                                         OpOperand &operand,
+                                         OpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
index 146e880765668b..632ee6e7b55852 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -36,6 +36,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <iterator>
+#include <numeric>
 #include <optional>
 #include <utility>
 
@@ -279,6 +280,20 @@ struct StructuredOpShardingInterface
     return res;
   }
 
+  SmallVector<ReductionKind>
+  getReductionLoopIteratorKinds(Operation *op) const {
+    LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+    SmallVector<utils::IteratorType> iteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    unsigned reductionItersCount = std::accumulate(
+        iteratorTypes.begin(), iteratorTypes.end(), 0,
+        [](unsigned count, utils::IteratorType iter) {
+          return count + (iter == utils::IteratorType::reduction);
+        });
+    mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
+    return SmallVector<ReductionKind>(reductionItersCount, reductionKind);
+  }
+
   LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
                         ArrayRef<MeshShardingAttr> operandShardings,
                         ArrayRef<MeshShardingAttr> resultShardings,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 03f11ad1f94965..343ec3ce7e317a 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include <algorithm>
 #include <functional>
 #include <iterator>
@@ -99,7 +101,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
 static FailureOr<MeshOp> getMeshAndVerify(Operation *op,
                                           FlatSymbolRefAttr meshSymbol,
                                           SymbolTableCollection &symbolTable) {
-  mesh::MeshOp mesh = getMesh(op, meshSymbol, symbolTable);
+  mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
   if (!mesh) {
     return op->emitError() << "Undefined required mesh symbol \""
                            << meshSymbol.getValue() << "\".";
@@ -178,6 +180,88 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
   return type;
 }
 
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  builder.setInsertionPointAfterValue(operandValue);
+  ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
+  if (shardOp && shardOp.getShard() == sharding &&
+      !shardOp.getAnnotateForUsers()) {
+    // Nothing to do: the user is already a shard op with this sharding.
+    return;
+  }
+
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ false);
+  IRRewriter rewriter(builder);
+  rewriter.replaceUsesWithIf(
+      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+        return use.getOwner() == operandOp && use.get() == operandValue;
+      });
+
+  if (!shardOp || shardOp.getAnnotateForUsers()) {
+    return;
+  }
+
+  auto newShardOp2 = builder.create<ShardOp>(
+      operandValue.getLoc(), newShardOp, sharding, /*annotate_for_users*/ true);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+}
+
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpResult result,
+                                                     OpBuilder &builder) {
+  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+    maybeInsertTargetShardingAnnotation(sharding, use, builder);
+  }
+}
+
+void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
+  Operation *operandSrcOp = operandValue.getDefiningOp();
+  bool isBlockArg = !operandSrcOp;
+  ShardOp shardOp = dyn_cast_or_null<ShardOp>(operandSrcOp);
+
+  if (shardOp && shardOp.getShard() == sharding &&
+      shardOp.getAnnotateForUsers()) {
+    // Nothing to do: the producer shard op already annotates this sharding.
+    return;
+  }
+
+  builder.setInsertionPoint(operandOp);
+  auto newShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ true);
+  IRRewriter rewriter(builder);
+  rewriter.replaceUsesWithIf(
+      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+        return use.getOwner() == operandOp && use.get() == operandValue;
+      });
+
+  if (isBlockArg || !shardOp || !shardOp.getAnnotateForUsers()) {
+    // Source is not an annotate_for_users shard op; no resharding needed.
+    return;
+  }
+
+  builder.setInsertionPoint(newShardOp);
+  auto newPreceedingShardOp =
+      builder.create<ShardOp>(operandValue.getLoc(), operandValue, sharding,
+                              /*annotate_for_users*/ false);
+  rewriter.replaceUsesWithIf(newShardOp.getOperand(), newPreceedingShardOp,
+                             [&newShardOp](OpOperand &use) {
+                               return use.getOwner() ==
+                                      newShardOp.getOperation();
+                             });
+}
+
 //===----------------------------------------------------------------------===//
 // mesh.mesh op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index 9acee5aa8d8604..4ba61b46b6e08e 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -399,11 +399,6 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
                                 AffineMap map,
                                 ArrayRef<utils::IteratorType> loopTypes,
                                 ArrayRef<ReductionKind> reductionLoopKinds) {
-  FailureOr<std::pair<bool, MeshShardingAttr>> maybeSharding =
-      getMeshShardingAttr(result);
-  if (succeeded(maybeSharding) && !maybeSharding->first)
-    return success();
-
   auto resultType = result.getType().cast<RankedTensorType>();
   SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
   SmallVector<MeshAxis> partialAxes;
@@ -440,11 +435,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
   removeTrailingEmptySubArray(splitAxes);
   MeshShardingAttr shardAttr = MeshShardingAttr::get(
       b.getContext(), shardingOption.mesh, splitAxes, partialAxes, partialType);
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPointAfterValue(result);
-  auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
-                                   shardAttr, /*annotate_for_users*/ false);
-  result.replaceAllUsesExcept(shardOp, shardOp);
+  maybeInsertTargetShardingAnnotation(shardAttr, result, b);
+
   return success();
 }
 
@@ -453,11 +445,8 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
 static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
                                 const ShardingOption &shardingOption,
                                 AffineMap map) {
-  auto maybeShardingAttr = getMeshShardingAttr(opOperand);
-  if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
-    return success();
-  Value operand = opOperand.get();
-  auto operandType = operand.getType().cast<RankedTensorType>();
+  Value operandValue = opOperand.get();
+  auto operandType = operandValue.getType().cast<RankedTensorType>();
   SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
   unsigned numDims = map.getNumDims();
   for (auto it : llvm::enumerate(map.getResults())) {
@@ -486,10 +475,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
   MeshShardingAttr shardAttr =
       MeshShardingAttr::get(b.getContext(), shardingOption.mesh, splitAxes);
   OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(opOperand.getOwner());
-  auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
-                                   shardAttr, true);
-  opOperand.set(shardOp);
+  maybeInsertSourceShardingAnnotation(shardAttr, opOperand, b);
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 29320f1e339f86..b3ffe4a219254d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
@@ -199,6 +200,7 @@ struct ShardingPropagation
 
     LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
                       << funcOp << "\n");
+    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
 
     // 2. propagate in original order
     for (Operation &op : llvm::make_early_inc_range(block))
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index e4868435135ed1..cc60a3d4826658 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -499,8 +499,6 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                                ShardOp target,
                                TypedValue<ShapedType> sourceShardValue) {
-  assert(!source.getAnnotateForUsers());
-  assert(target.getAnnotateForUsers());
   assert(source.getResult() == target.getOperand());
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
   return reshard(
@@ -635,7 +633,6 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
     targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
   } else {
     // Insert resharding.
-    assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
     TypedValue<ShapedType> srcSpmdValue =
         spmdizationMap.lookup(srcShardOp.getOperand())
             .cast<TypedValue<ShapedType>>();
diff --git a/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
new file mode 100644
index 00000000000000..59fd548dc2ef2c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-sharding-propagation.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt \
+// RUN:   --verify-each \
+// RUN:   --pass-pipeline="builtin.module(func.func(sharding-propagation))" \
+// RUN:   %s | FileCheck %s
+
+mesh.mesh @mesh_2(shape = 2)
+
+// CHECK-LABEL: func @matmul_shard_prallel_axis
+func.func @matmul_shard_prallel_axis(
+  // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<2x3xf32>,
+  %arg0 : tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x2xf32>,
+  %arg1 : tensor<3x2xf32>,
+  // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+) -> tensor<2x2xf32> {
+  // CHECK: %[[IN1_ANNOTATED_0:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[}}[0]]> : tensor<2x3xf32>
+  // CHECK: %[[IN1_ANNOTATED_1:.*]] = mesh.shard %[[IN1_ANNOTATED_0]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[IN2_ANNOTATED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[DPS_OUT_ANNOTATED:.*]] = mesh.shard %[[DPS_OUT]] to <@mesh_2, {{\[}}[0]]> annotate_for_users : tensor<2x2xf32>
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+  // CHECK: %[[RES:.*]] = linalg.matmul ins(%[[IN1_ANNOTATED_1]], %[[IN2_ANNOTATED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME:  outs(%[[DPS_OUT_ANNOTATED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[RES_ANNOTATED_0:.*]] = mesh.shard %[[RES]] to <@mesh_2, {{\[}}[0]]> : tensor<2x2xf32>
+  // CHECK: %[[RES_ANNOTATED_1:.*]] = mesh.shard %[[RES_ANNOTATED_0]] to <@mesh_2, {{\[}}[]]> annotate_for_users : tensor<2x2xf32>
+  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> annotate_for_users : tensor<2x2xf32>
+
+  // CHECK: return %[[RES_ANNOTATED_1]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 270787ab518831..f419b709df9f56 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation,cse))" %s | FileCheck %s
 
+mesh.mesh @mesh_2(shape = 2)
 mesh.mesh @mesh_1d(shape = ?)
 mesh.mesh @mesh_2d(shape = 2x4)
 mesh.mesh @mesh_3d(shape = ?x?x?)
@@ -73,12 +74,11 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
   // CHECK-NEXT:  %[[V5:.*]] = tosa.abs %[[V4]]
   // CHECK-NEXT:  %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
   %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
-  // CHECK-NEXT:  %[[V8:.*]] = tosa.negate %[[V7]]
-  // CHECK-NEXT:  %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+  // CHECK-NEXT:  %[[V7:.*]] = tosa.negate %[[V4]]
+  // CHECK-NEXT:  %[[V8:.*]] = mesh.shard %[[V7]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
   %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
   %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
-  // CHECK-NEXT: return %[[V6]], %[[V9]]
+  // CHECK-NEXT: return %[[V6]], %[[V8]]
   return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
 }
 
@@ -135,6 +135,36 @@ func.func @matmul_on_use_shard_m_and_duplicted_k(%arg0: tensor<2x16x8xf32>, %arg
   return %2 : tensor<2x16x32xf32>
 }
 
+// CHECK-LABEL: func.func @resolve_conflicting_annotations
+func.func @resolve_conflicting_annotations(
+  // CHECK-SAME: %[[IN1:.*]]: tensor<2x3xf32>,
+  %arg0: tensor<2x3xf32>,
+  // CHECK-SAME: %[[IN2:.*]]: tensor<3x2xf32>,
+  %arg1: tensor<3x2xf32>,
+  // CHECK-SAME: %[[OUT_DPS:.*]]: tensor<2x2xf32>
+  %out_dps: tensor<2x2xf32>
+// CHECK-SAME: ) -> tensor<2x2xf32> {
+) -> tensor<2x2xf32> {
+  // CHECK: %[[IN1_SHARDED1:.*]] = mesh.shard %[[IN1]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x3xf32>
+  // CHECK: %[[IN1_SHARDED2:.*]] = mesh.shard %[[IN1_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x3xf32>
+  // CHECK: %[[IN2_SHARDED:.*]] = mesh.shard %[[IN2]] to <@mesh_2, []> annotate_for_users : tensor<3x2xf32>
+  // CHECK: %[[OUT_DPS_SHARDED:.*]] = mesh.shard %[[OUT_DPS]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+  %arg0_sharded = mesh.shard %arg0 to <@mesh_2, [[0]]> : tensor<2x3xf32>
+
+  // CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[IN1_SHARDED2]], %[[IN2_SHARDED]] : tensor<2x3xf32>, tensor<3x2xf32>)
+  // CHECK-SAME: outs(%[[OUT_DPS_SHARDED]] : tensor<2x2xf32>) -> tensor<2x2xf32>
+  %res = linalg.matmul ins(%arg0_sharded, %arg1 : tensor<2x3xf32>, tensor<3x2xf32>)
+    outs(%out_dps : tensor<2x2xf32>) -> tensor<2x2xf32>
+
+  // CHECK: %[[MATMUL_SHARDED1:.*]] = mesh.shard %[[MATMUL]] to <@mesh_2, {{\[\[}}0]]> : tensor<2x2xf32>
+  // CHECK: %[[MATMUL_SHARDED2:.*]] = mesh.shard %[[MATMUL_SHARDED1]] to <@mesh_2, {{\[\[}}0]]> annotate_for_users : tensor<2x2xf32>
+  // CHECK: %[[MATMUL_SHARDED3:.*]] = mesh.shard %[[MATMUL_SHARDED2]] to <@mesh_2, {{\[\[}}]]> : tensor<2x2xf32>
+  %res_sharded = mesh.shard %res to <@mesh_2, [[]]> : tensor<2x2xf32>
+
+  // CHECK: return %[[MATMUL_SHARDED3]] : tensor<2x2xf32>
+  return %res_sharded : tensor<2x2xf32>
+}
+
 // https://arxiv.org/abs/2211.05102 Figure 2(a)
 // CHECK-LABEL: func.func @mlp_1d_weight_stationary
 // CHECK-SAME:     %[[ARG0:.*]]: tensor<2x4x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<2x32x8xf32>
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2df247aba35155..d7a1e2fd9d2790 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -16,6 +16,21 @@ func.func @full_replication(
   return %1 : tensor<2xi8>
 }
 
+// CHECK-LABEL: func @sharding_triplet
+func.func @sharding_triplet(
+  // CHECK-SAME: %[[ARG:.*]]: tensor<1xf32>
+  %arg0: tensor<2xf32>
+// CHECK-SAME: ) -> tensor<2xf32> {
+) -> tensor<2xf32> {
+  // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<1xf32> -> tensor<2xf32>
+  %sharding_annotated = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<2xf32>
+  %sharding_annotated_0 = mesh.shard %sharding_annotated to <@mesh_1d, [[0]]> annotate_for_users : tensor<2xf32>
+  %sharding_annotated_1 = mesh.shard %sharding_annotated_0 to <@mesh_1d, [[]]> : tensor<2xf32>
+  // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
+  return %sharding_annotated_1 : tensor<2xf32>
+}
+
+
 // CHECK-LABEL: func @move_split_axis
 func.func @move_split_axis(
   // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi8>



More information about the Mlir-commits mailing list