[Mlir-commits] [mlir] [MLIR][shard] checking for correct&full sharding annotations (PR #176000)

Frank Schlimbach llvmlistbot at llvm.org
Thu Jan 15 07:27:20 PST 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/176000

>From 651e91e44b110efd834e029e0f1375c132cb6889 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Jan 2026 09:46:36 -0800
Subject: [PATCH 1/2] checking for correct&full sharding annotations

---
 .../Dialect/Shard/Transforms/Partition.cpp    | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 335ca1a60f8f3..bf0617c107f5e 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -663,6 +663,32 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
   return success();
 }
 
+// Check if the operation is correctly and fully annotated with sharding
+// information:
+//   - Operation results must have exactly one use (e.g. the shard operation).
+//   - All operands and all results must be annotated, i.e. they must be
+//     produced by/consumed by a shard.shard operation.
+//   - Result annotations must not include the 'annotate_for_users' attribute.
+//   - Operand annotations must include the 'annotate_for_users' attribute.
+// raises an error if the operation is not correctly and fully annotated.
+static void checkFullyAnnotated(Operation *op) {
+  // constant ops do not need to have sharding annotations
+  if (op->hasTrait<OpTrait::ConstantLike>())
+    return;
+  for (auto operand : op->getOperands()) {
+    if (!operand.getDefiningOp<ShardOp>())
+      op->emitError("Cannot partition: all operands must be produced by a "
+                    "shard.shard operation");
+  }
+  for (auto result : op->getResults()) {
+    if (!result.hasOneUse())
+      op->emitError("Cannot partition: all results must have exactly one use");
+    if (!(isa<ShardOp>(*result.user_begin())))
+      op->emitError(
+          "Cannot partition: all result users must be shard.shard operations");
+  }
+}
+
 static LogicalResult
 partitionOperation(Operation &op, IRMapping &partitionMap,
                    SymbolTableCollection &symbolTableCollection,
@@ -670,6 +696,7 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
   if (isa<ShardingOp>(op)) {
     return success();
   }
+
   if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
     auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
     if (!shardOp) {
@@ -686,6 +713,9 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
                               builder);
   }
 
+  // check if operation is correctly and fully annotated
+  checkFullyAnnotated(&op);
+
   SmallVector<Value> partitionedOperands;
   llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
                   [&partitionMap](Value operand) {

>From 700bf198b4e0ee5f773b8edef417382ecfda5fd6 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 15 Jan 2026 07:18:11 -0800
Subject: [PATCH 2/2] checking block args, adding tests

---
 .../Dialect/Shard/Transforms/Partition.cpp    | 73 ++++++++++++++++---
 mlir/test/Dialect/Arith/shard-partition.mlir  |  3 +-
 .../test/Dialect/Shard/invalid_annotated.mlir | 54 ++++++++++++++
 mlir/test/Dialect/Shard/partition.mlir        | 13 ++--
 4 files changed, 126 insertions(+), 17 deletions(-)
 create mode 100644 mlir/test/Dialect/Shard/invalid_annotated.mlir

diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index bf0617c107f5e..1daa65b35facf 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -663,6 +663,32 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
   return success();
 }
 
+// Check if the block args are correctly annotated with sharding information:
+//   - non-tensor and 0d-tensor args are ignored
+//   - each tensor arg must have exactly one use (unused args are rejected as
+//     well), and that single user must be a shard.shard operation
+static LogicalResult checkFullyAnnotated(Block &block) {
+  for (BlockArgument arg : block.getArguments()) {
+    auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
+    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+      continue;
+
+    // hasOneUse() also rules out zero uses, so getUsers().begin() is valid.
+    if (!rankedTensorArg.hasOneUse())
+      return emitError(block.getParent()->getLoc())
+             << "Cannot partition: expected a single use for block argument "
+             << arg.getArgNumber() << " in block "
+             << block.computeBlockNumber();
+    Operation *useOp = *rankedTensorArg.getUsers().begin();
+    if (!isa<ShardOp>(useOp))
+      return emitError(block.getParent()->getLoc())
+             << "Cannot partition: expected a shard.shard op for block "
+             << "argument " << arg.getArgNumber() << " in block "
+             << block.computeBlockNumber();
+  }
+  return success();
+}
+
 // Check if the operation is correctly and fully annotated with sharding
 // information:
 //   - Operation results must have exactly one use (e.g. the shard operation).
@@ -671,22 +697,48 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
 //   - Result annotations must not include the 'annotate_for_users' attribute.
 //   - Operand annotations must include the 'annotate_for_users' attribute.
 // raises an error if the operation is not correctly and fully annotated.
-static void checkFullyAnnotated(Operation *op) {
+static LogicalResult checkFullyAnnotated(Operation *op) {
   // constant ops do not need to have sharding annotations
   if (op->hasTrait<OpTrait::ConstantLike>())
-    return;
-  for (auto operand : op->getOperands()) {
-    if (!operand.getDefiningOp<ShardOp>())
-      op->emitError("Cannot partition: all operands must be produced by a "
-                    "shard.shard operation");
+    return success();
+
+  for (OpOperand &operand : op->getOpOperands()) {
+    // non-tensor and 0d-tensor operands are ignored
+    auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
+    if (!rankedTT || rankedTT.getRank() == 0)
+      continue;
+
+    auto shard = operand.get().getDefiningOp<ShardOp>();
+    if (!shard)
+      return op->emitError() << "Cannot partition: tensor operand "
+                             << operand.getOperandNumber()
+                             << " must be defined by a shard.shard operation.";
+    if (!shard.getAnnotateForUsers())
+      return op->emitError()
+             << "Cannot partition: shard.shard for operand "
+             << operand.getOperandNumber() << " must set 'annotate_for_users'.";
   }
   for (auto result : op->getResults()) {
+    // non-tensor and 0d-tensor results are ignored, as for operands above
+    auto rankedTT = dyn_cast<RankedTensorType>(result.getType());
+    if (!rankedTT || rankedTT.getRank() == 0)
+      continue;
+
     if (!result.hasOneUse())
-      op->emitError("Cannot partition: all results must have exactly one use");
-    if (!(isa<ShardOp>(*result.user_begin())))
-      op->emitError(
-          "Cannot partition: all result users must be shard.shard operations");
+      return op->emitError()
+             << "Cannot partition: result " << result.getResultNumber()
+             << " must have exactly one use.";
+    auto shard = dyn_cast<ShardOp>(*result.user_begin());
+    if (!shard)
+      return op->emitError()
+             << "Cannot partition: user of result " << result.getResultNumber()
+             << " must be shard.shard operation.";
+    if (shard.getAnnotateForUsers())
+      return op->emitError() << "Cannot partition: shard.shard for result "
+                             << result.getResultNumber()
+                             << " must not set 'annotate_for_users'.";
   }
+  return success();
 }

 static LogicalResult
@@ -714,7 +761,8 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
   }
 
   // check if operation is correctly and fully annotated
-  checkFullyAnnotated(&op);
+  if (failed(checkFullyAnnotated(&op)))
+    return failure();
 
   SmallVector<Value> partitionedOperands;
   llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
@@ -732,6 +780,9 @@ partitionBlock(Block &block, IRMapping &partitionMap,
                SymbolTableCollection &symbolTableCollection,
                OpBuilder &builder) {
 
+  if (failed(checkFullyAnnotated(block)))
+    return failure();
+
   SmallVector<Location> argLocations;
   llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                   [](BlockArgument arg) { return arg.getLoc(); });
diff --git a/mlir/test/Dialect/Arith/shard-partition.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir
index be894278e5e95..833d371439bb5 100644
--- a/mlir/test/Dialect/Arith/shard-partition.mlir
+++ b/mlir/test/Dialect/Arith/shard-partition.mlir
@@ -13,5 +13,6 @@ func.func @test_partition_constant() ->(tensor<1024x1024xf32>)attributes{llvm.em
   %sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
   %sharded_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32>
   %ci = arith.constant 434 : i32
-  return %sharded_1 : tensor<1024x1024xf32>
+  %sharded_r = shard.shard %sharded_1 to %sharding_1 annotate_for_users : tensor<1024x1024xf32>
+  return %sharded_r : tensor<1024x1024xf32>
 }
diff --git a/mlir/test/Dialect/Shard/invalid_annotated.mlir b/mlir/test/Dialect/Shard/invalid_annotated.mlir
new file mode 100644
index 0000000000000..06bc1eba5516b
--- /dev/null
+++ b/mlir/test/Dialect/Shard/invalid_annotated.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN:   -verify-diagnostics %s
+
+shard.grid @grid(shape = 2)
+
+// expected-error @+1 {{Cannot partition: expected a shard.shard op for block argument 0 in block 0}}
+func.func @test_block_arg_missing_shard(%arg0: tensor<6xi32>) -> tensor<6xi32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %2 = tosa.abs %arg0 : (tensor<6xi32>) -> tensor<6xi32>
+  %sharded = shard.shard %2 to %sharding annotate_for_users : tensor<6xi32>
+  return %2 : tensor<6xi32>
+}
+
+// expected-error @+1 {{Cannot partition: expected a single use for block argument 0 in block 0}}
+func.func @test_block_arg_multiple_uses(%arg0: tensor<6xi32>) -> tensor<6xi32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %0 = shard.shard %arg0 to %sharding : tensor<6xi32>
+  %1 = shard.shard %arg0 to %sharding : tensor<6xi32>
+  %2 = shard.shard %0 to %sharding annotate_for_users : tensor<6xi32>
+  return %2 : tensor<6xi32>
+}
+
+func.func @test_operand_missing_annotate(%arg0: tensor<6xi32>) -> tensor<6xi32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %2 = shard.shard %arg0 to %sharding : tensor<6xi32>
+  // expected-error @+1 {{Cannot partition: shard.shard for operand 0 must set 'annotate_for_users'.}}
+  %3 = tosa.rsqrt %2 : (tensor<6xi32>) -> tensor<6xi32>
+  %4 = tosa.rsqrt %3 : (tensor<6xi32>) -> tensor<6xi32>
+  %sharded = shard.shard %4 to %sharding annotate_for_users : tensor<6xi32>
+  return %sharded : tensor<6xi32>
+}
+
+func.func @test_result_missing_sharding() -> tensor<6xi32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  // expected-error @+1 {{Cannot partition: user of result 0 must be shard.shard operation.}}
+  %1 = tensor.empty() : tensor<6xi32>
+  %3 = tosa.rsqrt %1 : (tensor<6xi32>) -> tensor<6xi32>
+  %4 = shard.shard %3 to %sharding : tensor<6xi32>
+  %sharded = shard.shard %4 to %sharding annotate_for_users : tensor<6xi32>
+  return %sharded : tensor<6xi32>
+}
+
+func.func @test_multiple_users(%arg0: tensor<6xi32>) -> tensor<6xi32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %arg0 to %sharding : tensor<6xi32>
+  %2 = shard.shard %1 to %sharding annotate_for_users : tensor<6xi32>
+  // expected-error @+1 {{Cannot partition: result 0 must have exactly one use.}}
+  %3 = tosa.rsqrt %2 : (tensor<6xi32>) -> tensor<6xi32>
+  %4 = shard.shard %3 to %sharding : tensor<6xi32>
+  %sharded = shard.shard %3 to %sharding annotate_for_users : tensor<6xi32>
+  return %sharded : tensor<6xi32>
+}
+
+func.func @test_result_invalid_annotate(%arg0: tensor<6xi32>) -> tensor<6xi32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %1 = shard.shard %arg0 to %sharding : tensor<6xi32>
+  %2 = shard.shard %1 to %sharding annotate_for_users : tensor<6xi32>
+  // expected-error @+1 {{Cannot partition: shard.shard for result 0 must not set 'annotate_for_users'.}}
+  %3 = tosa.rsqrt %2 : (tensor<6xi32>) -> tensor<6xi32>
+  %sharded = shard.shard %3 to %sharding annotate_for_users : tensor<6xi32>
+  return %sharded : tensor<6xi32>
+}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index c2572cc3b987b..0f293a39608e3 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -14,8 +14,9 @@ func.func @return_sharding(
   %sharded = shard.shard %arg0 to %ssharded  : tensor<2xf32>
   // CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding
   %r = shard.get_sharding %sharded : tensor<2xf32> -> !shard.sharding
+  %sharded_r = shard.shard %sharded to %ssharded annotate_for_users : tensor<2xf32>
   // CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding
-  return %sharded, %r : tensor<2xf32>, !shard.sharding
+  return %sharded_r, %r : tensor<2xf32>, !shard.sharding
 }
 
 // CHECK-LABEL: func @full_replication
@@ -44,7 +45,7 @@ func.func @sharding_triplet(
   %ssharded_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
   %sharded_0 = shard.shard %sharded to %ssharded_0  annotate_for_users : tensor<2xf32>
   %ssharded_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
-  %sharded_1 = shard.shard %sharded_0 to %ssharded_1  : tensor<2xf32>
+  %sharded_1 = shard.shard %sharded_0 to %ssharded_1 annotate_for_users : tensor<2xf32>
   // CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
   return %sharded_1 : tensor<2xf32>
 }
@@ -197,9 +198,10 @@ func.func @incomplete_sharding(
   // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
   %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
   %s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
-  %2 = shard.shard %1 to %s2  : tensor<8x16xf32>
+  %2 = shard.shard %1 to %s2 : tensor<8x16xf32>
+  %3 = shard.shard %2 to %s2 annotate_for_users : tensor<8x16xf32>
   // CHECK: return %[[RES]] : tensor<4x16xf32>
-  return %2 : tensor<8x16xf32>
+  return %3 : tensor<8x16xf32>
 }
 
 shard.grid @grid_1d_4(shape = 4)
@@ -301,7 +303,8 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
   %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
   %sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
   %4 = tensor.empty() : tensor<6xi32>
-  %sharded_out = shard.shard %4 to %sharding : tensor<6xi32>
+  %sharded_4 = shard.shard %4 to %sharding : tensor<6xi32>
+  %sharded_out = shard.shard %sharded_4 to %sharding annotate_for_users : tensor<6xi32>
   %sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
   // CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
   %reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1] 



More information about the Mlir-commits mailing list