[Mlir-commits] [mlir] [mlir][mesh] adding option for traversal order in sharding propagation (PR #144079)

Frank Schlimbach llvmlistbot at llvm.org
Tue Jun 17 09:43:43 PDT 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/144079

>From 2fa80040cb416fc6ecf225838e9eb7cc12f88c8c Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 4 Apr 2025 08:54:33 +0200
Subject: [PATCH 1/4] adding option for traversal order in sharding propagation

---
 .../mlir/Dialect/Mesh/Transforms/Passes.h     | 12 +++++
 .../mlir/Dialect/Mesh/Transforms/Passes.td    | 15 ++++++
 .../Mesh/Transforms/ShardingPropagation.cpp   | 40 +++++++++++----
 .../Mesh/forward-sharding-propagation.mlir    | 49 +++++++++++++++++++
 4 files changed, 106 insertions(+), 10 deletions(-)
 create mode 100644 mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae..a2424d43a8ba9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
 
 namespace mesh {
 
+/// Selects the order in which the sharding propagation pass visits operations.
+enum class TraversalOrder {
+  /// Visit operations once, in their original (forward) order.
+  Forward,
+  /// Visit operations once, in reversed (backward) order.
+  Backward,
+  /// Visit operations in forward order first, then again in backward order.
+  ForwardBackward,
+  /// Visit operations in backward order first, then again in forward order.
+  BackwardForward
+};
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d64..11ec7e78cd5e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
+  let options = [
+    Option<"traversal", "traversal",
+           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+           "Traversal order to use for sharding propagation:",
+            [{::llvm::cl::values(
+              clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+              "Forward only traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+              "backward only traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+              "forward-backward traversal."),
+              clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+              "backward-forward traversal.")
+            )}]>,
+  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9d..9d4a144912ee2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -383,17 +386,34 @@ struct ShardingPropagation
         });
 
     // 1. propagate in reversed order
-    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
-
-    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
-                      << funcOp << "\n");
-    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward) {
+      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
+      if (traversal == TraversalOrder::BackwardForward) {
+        LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
+                          << funcOp << "\n");
+        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      }
+    }
 
     // 2. propagate in original order
-    for (Operation &op : llvm::make_early_inc_range(block))
-      if (failed(visitOp(&op, builder)))
-        return signalPassFailure();
+    if (traversal != TraversalOrder::Backward) {
+      for (Operation &op : llvm::make_early_inc_range(block))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
+      if (traversal == TraversalOrder::ForwardBackward) {
+        LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
+                          << funcOp << "\n");
+        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      }
+    }
+
+    // 3. propagate in backward order if needed
+    if (traversal == TraversalOrder::ForwardBackward)
+      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+        if (failed(visitOp(&op, builder)))
+          return signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..98e9931b8de94
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[v1:%.*]] = linalg.fill ins
+    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+    %3 = tensor.empty() : tensor<6x6xi32>
+    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %c0_i32 = arith.constant 0 : i32
+    %6 = tensor.empty() : tensor<i32>
+    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial =  sum [0] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
+      (%in: i32, %init: i32) {
+        %9 = arith.addi %in, %init : i32
+        linalg.yield %9 : i32
+      }
+    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+  }
+}

>From 6d7d2d355af776d8ca24e62b091133ceaa7a835b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 17 Jun 2025 18:42:06 +0200
Subject: [PATCH 2/4] fixing invalid modification of use-range while iterating

---
 mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h |  3 ---
 mlir/lib/Dialect/Mesh/IR/MeshOps.cpp        | 27 ++++++++++++---------
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 32c2eca2cefa8..3878505f8f93f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -206,9 +206,6 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
-void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                         OpOperand &operand, OpBuilder &builder,
-                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304cb55a35086..a2c2d1a7470cc 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -275,13 +275,12 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
-                                                     OpOperand &operand,
-                                                     OpBuilder &builder,
-                                                     ShardOp &newShardOp) {
+static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
+                                                    Value &operandValue,
+                                                    Operation *operandOp,
+                                                    OpBuilder &builder,
+                                                    ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
-  Value operandValue = operand.get();
-  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -300,9 +299,8 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
         builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                 /*annotate_for_users*/ false);
   }
-  IRRewriter rewriter(builder);
-  rewriter.replaceUsesWithIf(
-      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
+  operandValue.replaceUsesWithIf(
+      newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
@@ -313,15 +311,20 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
   auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                              newShardOp.getSharding(),
                                              /*annotate_for_users*/ true);
-  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
+  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
-    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
+  SmallVector<std::pair<Value, Operation *>> uses;
+  for (auto &use : result.getUses()) {
+    uses.emplace_back(use.get(), use.getOwner());
+  }
+  for (auto &[operandValue, operandOp] : uses) {
+    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
+                                            builder, newShardOp);
   }
 }
 

>From 3f60b3af884974621d4a67775becd7dff363142e Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 17 Jun 2025 18:42:48 +0200
Subject: [PATCH 3/4] code deduplication

---
 .../Mesh/Transforms/ShardingPropagation.cpp   | 48 +++++++++----------
 1 file changed, 22 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 9d4a144912ee2..6751fafaf1776 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -385,35 +385,31 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
-    // 1. propagate in reversed order
-    if (traversal == TraversalOrder::Backward ||
-        traversal == TraversalOrder::BackwardForward) {
-      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-        if (failed(visitOp(&op, builder)))
-          return signalPassFailure();
-      if (traversal == TraversalOrder::BackwardForward) {
-        LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
-                          << funcOp << "\n");
-        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+    auto traverse = [&](auto &&range, OpBuilder &builder,
+                        const char *order) -> bool {
+      for (Operation &op : range) {
+        if (failed(visitOp(&op, builder))) {
+          signalPassFailure();
+          return true;
+        }
       }
-    }
+      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
+                        << funcOp << "\n");
+      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+      return false;
+    };
 
-    // 2. propagate in original order
-    if (traversal != TraversalOrder::Backward) {
-      for (Operation &op : llvm::make_early_inc_range(block))
-        if (failed(visitOp(&op, builder)))
-          return signalPassFailure();
-      if (traversal == TraversalOrder::ForwardBackward) {
-        LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
-                          << funcOp << "\n");
-        LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-      }
-    }
+    // 1. Propagate in reversed order.
+    if (traversal == TraversalOrder::Backward ||
+        traversal == TraversalOrder::BackwardForward)
+      traverse(llvm::reverse(block), builder, "backward");
+
+    // 2. Propagate in original order.
+    if (traversal != TraversalOrder::Backward)
+      traverse(block, builder, "forward");
 
-    // 3. propagate in backward order if needed
+    // 3. Propagate in backward order if needed.
     if (traversal == TraversalOrder::ForwardBackward)
-      for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
-        if (failed(visitOp(&op, builder)))
-          return signalPassFailure();
+      traverse(llvm::reverse(block), builder, "backward");
   }
 };

>From e0678d30275c54ae747abfa6da5e2ed6cf1db5cc Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Tue, 17 Jun 2025 18:43:31 +0200
Subject: [PATCH 4/4] adding tests for backward and forward-backward sharding
 propagation

---
 .../Mesh/backward-sharding-propagation.mlir   | 26 ++++++++++++++++++
 ...forward-backward-sharding-propagation.mlir | 27 +++++++++++++++++++
 2 files changed, 53 insertions(+)
 create mode 100644 mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
 create mode 100644 mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir

diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..4223d01d65111
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    // CHECK-COUNT-2: mesh.shard
+    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
+    // CHECK: tensor.empty()
+    // CHECK-NOT: mesh.shard @
+    %2 = tensor.empty() : tensor<6x6xi32>
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    // CHECK: return
+    return %3 : tensor<6x6xi32>
+  }
+}
diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..817fd6ae871fc
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+  func.func @test_forward() -> tensor<6x6xi32> {
+    %c1_i32 = arith.constant 1 : i32
+    // CHECK: tensor.empty()
+    %0 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
+    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
+    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
+    %2 = tensor.empty() : tensor<6x6xi32>
+    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
+    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
+        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
+    ^bb0(%in: i32, %in_2: i32, %out: i32):
+      %9 = arith.addi %in, %in_2 : i32
+      linalg.yield %9 : i32
+    } -> tensor<6x6xi32>
+    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
+    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
+    // CHECK: return
+    return %annotated_col : tensor<6x6xi32>
+  }
+}
\ No newline at end of file



More information about the Mlir-commits mailing list