[Mlir-commits] [mlir] [MLIR][shard] checking for correct&full sharding annotations (PR #176000)
Frank Schlimbach
llvmlistbot at llvm.org
Thu Jan 15 07:27:20 PST 2026
https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/176000
>From 651e91e44b110efd834e029e0f1375c132cb6889 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 14 Jan 2026 09:46:36 -0800
Subject: [PATCH 1/2] checking for correct&full sharding annotations
---
.../Dialect/Shard/Transforms/Partition.cpp | 30 +++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 335ca1a60f8f3..bf0617c107f5e 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -663,6 +663,32 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
return success();
}
+// Check if the operation is correctly and fully annotated with sharding
+// information:
+// - Operation results must have exactly one use (e.g. the shard operation).
+// - All operands and all results must be annotated, i.e. operands must be
+//   produced by, and results consumed by, a shard.shard operation.
+// - Result annotations must not include the 'annotate_for_users' attribute.
+// - Operand annotations must include the 'annotate_for_users' attribute.
+// raises an error if the operation is not correctly and fully annotated.
+static void checkFullyAnnotated(Operation *op) {
+ // constant ops do not need to have sharding annotations
+ if (op->hasTrait<OpTrait::ConstantLike>())
+ return;
+ for (auto operand : op->getOperands()) {
+ if (!operand.getDefiningOp<ShardOp>())
+ op->emitError("Cannot partition: all operands must be produced by a "
+ "shard.shard operation");
+ }
+ for (auto result : op->getResults()) {
+ if (!result.hasOneUse())
+ op->emitError("Cannot partition: all results must have exactly one use");
+ if (!(isa<ShardOp>(*result.user_begin())))
+ op->emitError(
+ "Cannot partition: all result users must be shard.shard operations");
+ }
+}
+
static LogicalResult
partitionOperation(Operation &op, IRMapping &partitionMap,
SymbolTableCollection &symbolTableCollection,
@@ -670,6 +696,7 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
if (isa<ShardingOp>(op)) {
return success();
}
+
if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
if (!shardOp) {
@@ -686,6 +713,9 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
builder);
}
+ // check if operation is correctly and fully annotated
+ checkFullyAnnotated(&op);
+
SmallVector<Value> partitionedOperands;
llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
[&partitionMap](Value operand) {
>From 700bf198b4e0ee5f773b8edef417382ecfda5fd6 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 15 Jan 2026 07:18:11 -0800
Subject: [PATCH 2/2] checking block args, adding tests
---
.../Dialect/Shard/Transforms/Partition.cpp | 73 ++++++++++++++++---
mlir/test/Dialect/Arith/shard-partition.mlir | 3 +-
.../test/Dialect/Shard/invalid_annotated.mlir | 54 ++++++++++++++
mlir/test/Dialect/Shard/partition.mlir | 13 ++--
4 files changed, 126 insertions(+), 17 deletions(-)
create mode 100644 mlir/test/Dialect/Shard/invalid_annotated.mlir
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index bf0617c107f5e..1daa65b35facf 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -663,6 +663,32 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
return success();
}
+// Check if the block args are correctly annotated with sharding information:
+// - non-tensor and 0d-tensor args are ignored;
+// - every other arg must have exactly one use, and that use must be a
+//   shard.shard operation (so an unused arg is rejected: it has no
+//   annotation). Emits an error at the enclosing region's location if not.
+static LogicalResult checkFullyAnnotated(Block &block) {
+  for (BlockArgument arg : block.getArguments()) {
+    auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
+    if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
+      continue;
+
+    // hasOneUse() also rejects 0 uses, keeping the deref below in bounds.
+    if (!rankedTensorArg.hasOneUse())
+      return emitError(block.getParent()->getLoc())
+             << "Cannot partition: expected a single use for block argument "
+             << arg.getArgNumber() << " in block "
+             << block.computeBlockNumber();
+    if (!isa<ShardOp>(*rankedTensorArg.getUsers().begin()))
+      return emitError(block.getParent()->getLoc())
+             << "Cannot partition: expected a shard.shard op for block "
+             << "argument " << arg.getArgNumber() << " in block "
+             << block.computeBlockNumber();
+  }
+  return success();
+}
+
// Check if the operation is correctly and fully annotated with sharding
// information:
// - Operation results must have exactly one use (e.g. the shard operation).
@@ -671,22 +697,51 @@ partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
// - Result annotations must not include the 'annotate_for_users' attribute.
// - Operand annotations must include the 'annotate_for_users' attribute.
-// raises an error if the operation is not correctly and fully annotated.
+// Only ranked tensor operands and results of rank > 0 are checked; other
+// values, and ConstantLike operations, are exempt. Emits an error on the
+// operation and returns failure if it is not correctly and fully annotated.
-static void checkFullyAnnotated(Operation *op) {
+static LogicalResult checkFullyAnnotated(Operation *op) {
// constant ops do not need to have sharding annotations
if (op->hasTrait<OpTrait::ConstantLike>())
- return;
- for (auto operand : op->getOperands()) {
- if (!operand.getDefiningOp<ShardOp>())
- op->emitError("Cannot partition: all operands must be produced by a "
- "shard.shard operation");
+    return success();
+
+  for (OpOperand &operand : op->getOpOperands()) {
+    // non-tensor and 0d-tensor operands are ignored
+    auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
+    if (!rankedTT || rankedTT.getRank() == 0)
+      continue;
+
+    auto shardOp = operand.get().getDefiningOp<ShardOp>();
+    if (!shardOp)
+      return op->emitError() << "Cannot partition: tensor operand "
+                             << operand.getOperandNumber()
+                             << " must be defined by a shard.shard operation.";
+    if (!shardOp.getAnnotateForUsers())
+      return op->emitError()
+             << "Cannot partition: shard.shard for operand "
+             << operand.getOperandNumber()
+             << " must set 'annotate_for_users'.";
}
for (auto result : op->getResults()) {
+    // non-tensor and 0d-tensor results are ignored, exactly like operands:
+    // shard.shard can only annotate tensors.
+    auto resultTT = dyn_cast<RankedTensorType>(result.getType());
+    if (!resultTT || resultTT.getRank() == 0)
+      continue;
if (!result.hasOneUse())
- op->emitError("Cannot partition: all results must have exactly one use");
- if (!(isa<ShardOp>(*result.user_begin())))
- op->emitError(
- "Cannot partition: all result users must be shard.shard operations");
+      return op->emitError()
+             << "Cannot partition: result " << result.getResultNumber()
+             << " must have exactly one use.";
+    auto shardOp = dyn_cast<ShardOp>(*result.user_begin());
+    if (!shardOp)
+      return op->emitError()
+             << "Cannot partition: user of result " << result.getResultNumber()
+             << " must be shard.shard operation.";
+    if (shardOp.getAnnotateForUsers())
+      return op->emitError() << "Cannot partition: shard.shard for result "
+                             << result.getResultNumber()
+                             << " must not set 'annotate_for_users'.";
}
+  return success();
}
static LogicalResult
@@ -714,7 +761,8 @@ partitionOperation(Operation &op, IRMapping &partitionMap,
}
// check if operation is correctly and fully annotated
- checkFullyAnnotated(&op);
+ if (failed(checkFullyAnnotated(&op)))
+ return failure();
SmallVector<Value> partitionedOperands;
llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
@@ -732,6 +780,9 @@ partitionBlock(Block &block, IRMapping &partitionMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
+ if (failed(checkFullyAnnotated(block)))
+ return failure();
+
SmallVector<Location> argLocations;
llvm::transform(block.getArguments(), std::back_inserter(argLocations),
[](BlockArgument arg) { return arg.getLoc(); });
diff --git a/mlir/test/Dialect/Arith/shard-partition.mlir b/mlir/test/Dialect/Arith/shard-partition.mlir
index be894278e5e95..833d371439bb5 100644
--- a/mlir/test/Dialect/Arith/shard-partition.mlir
+++ b/mlir/test/Dialect/Arith/shard-partition.mlir
@@ -13,5 +13,6 @@ func.func @test_partition_constant() ->(tensor<1024x1024xf32>)attributes{llvm.em
%sharding_1 = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding
%sharded_1 = shard.shard %cst to %sharding_1 : tensor<1024x1024xf32>
%ci = arith.constant 434 : i32
- return %sharded_1 : tensor<1024x1024xf32>
+ %sharded_r = shard.shard %sharded_1 to %sharding_1 annotate_for_users : tensor<1024x1024xf32>
+ return %sharded_r : tensor<1024x1024xf32>
}
diff --git a/mlir/test/Dialect/Shard/invalid_annotated.mlir b/mlir/test/Dialect/Shard/invalid_annotated.mlir
new file mode 100644
index 0000000000000..06bc1eba5516b
--- /dev/null
+++ b/mlir/test/Dialect/Shard/invalid_annotated.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition,test-single-fold))" \
+// RUN: -verify-diagnostics %s
+
+shard.grid @grid(shape = 2) // 1-D grid of two devices; every test below splits tensor dim 0 over it.
+
+// expected-error @+1 {{Cannot partition: expected a shard.shard op for block argument 0 in block 0}}
+func.func @test_block_arg_missing_shard(%in: tensor<6xi32>) -> tensor<6xi32> {
+  %split0 = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %abs = tosa.abs %in : (tensor<6xi32>) -> tensor<6xi32> // sole user of %in is not shard.shard
+  %abs_sharded = shard.shard %abs to %split0 annotate_for_users : tensor<6xi32>
+  return %abs : tensor<6xi32>
+}
+
+func.func @test_operand_missing_annotate(%in: tensor<6xi32>) -> tensor<6xi32> {
+  %split0 = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %in_sharded = shard.shard %in to %split0 : tensor<6xi32> // deliberately lacks annotate_for_users
+  // expected-error @+1 {{Cannot partition: shard.shard for operand 0 must set 'annotate_for_users'.}}
+  %rsqrt0 = tosa.rsqrt %in_sharded : (tensor<6xi32>) -> tensor<6xi32>
+  %rsqrt1 = tosa.rsqrt %rsqrt0 : (tensor<6xi32>) -> tensor<6xi32>
+  %out = shard.shard %rsqrt1 to %split0 annotate_for_users : tensor<6xi32>
+  return %out : tensor<6xi32>
+}
+
+func.func @test_result_missing_sharding() -> tensor<6xi32> {
+  %split0 = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  // expected-error @+1 {{Cannot partition: user of result 0 must be shard.shard operation.}}
+  %empty = tensor.empty() : tensor<6xi32>
+  %rsqrt = tosa.rsqrt %empty : (tensor<6xi32>) -> tensor<6xi32> // consumes %empty directly
+  %rsqrt_sharded = shard.shard %rsqrt to %split0 : tensor<6xi32>
+  %out = shard.shard %rsqrt_sharded to %split0 annotate_for_users : tensor<6xi32>
+  return %out : tensor<6xi32>
+}
+
+func.func @test_multiple_users(%in: tensor<6xi32>) -> tensor<6xi32> {
+  %split0 = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %in_sharded = shard.shard %in to %split0 : tensor<6xi32>
+  %in_annotated = shard.shard %in_sharded to %split0 annotate_for_users : tensor<6xi32>
+  // expected-error @+1 {{Cannot partition: result 0 must have exactly one use.}}
+  %rsqrt = tosa.rsqrt %in_annotated : (tensor<6xi32>) -> tensor<6xi32>
+  %use0 = shard.shard %rsqrt to %split0 : tensor<6xi32> // first user of %rsqrt
+  %use1 = shard.shard %rsqrt to %split0 annotate_for_users : tensor<6xi32> // second user
+  return %use1 : tensor<6xi32>
+}
+
+func.func @test_result_invalid_annotate(%in: tensor<6xi32>) -> tensor<6xi32> {
+  %split0 = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %in_sharded = shard.shard %in to %split0 : tensor<6xi32>
+  %in_annotated = shard.shard %in_sharded to %split0 annotate_for_users : tensor<6xi32>
+  // expected-error @+1 {{Cannot partition: shard.shard for result 0 must not set 'annotate_for_users'.}}
+  %rsqrt = tosa.rsqrt %in_annotated : (tensor<6xi32>) -> tensor<6xi32>
+  %out = shard.shard %rsqrt to %split0 annotate_for_users : tensor<6xi32> // wrongly marked for users
+  return %out : tensor<6xi32>
+}
diff --git a/mlir/test/Dialect/Shard/partition.mlir b/mlir/test/Dialect/Shard/partition.mlir
index c2572cc3b987b..0f293a39608e3 100644
--- a/mlir/test/Dialect/Shard/partition.mlir
+++ b/mlir/test/Dialect/Shard/partition.mlir
@@ -14,8 +14,9 @@ func.func @return_sharding(
%sharded = shard.shard %arg0 to %ssharded : tensor<2xf32>
// CHECK-NEXT: [[vsharding:%.*]] = shard.sharding @grid_1d split_axes = {{\[\[}}0]] : !shard.sharding
%r = shard.get_sharding %sharded : tensor<2xf32> -> !shard.sharding
+ %sharded_r = shard.shard %sharded to %ssharded annotate_for_users : tensor<2xf32>
// CHECK-NEXT: return [[ARG]], [[vsharding]] : tensor<1xf32>, !shard.sharding
- return %sharded, %r : tensor<2xf32>, !shard.sharding
+ return %sharded_r, %r : tensor<2xf32>, !shard.sharding
}
// CHECK-LABEL: func @full_replication
@@ -44,7 +45,7 @@ func.func @sharding_triplet(
%ssharded_0 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
%sharded_0 = shard.shard %sharded to %ssharded_0 annotate_for_users : tensor<2xf32>
%ssharded_1 = shard.sharding @grid_1d split_axes = [[]] : !shard.sharding
- %sharded_1 = shard.shard %sharded_0 to %ssharded_1 : tensor<2xf32>
+ %sharded_1 = shard.shard %sharded_0 to %ssharded_1 annotate_for_users : tensor<2xf32>
// CHECK: return %[[ALL_GATHER]] : tensor<2xf32>
return %sharded_1 : tensor<2xf32>
}
@@ -197,9 +198,10 @@ func.func @incomplete_sharding(
// CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
%1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
%s2 = shard.sharding @grid_1d split_axes = [[0]] : !shard.sharding
- %2 = shard.shard %1 to %s2 : tensor<8x16xf32>
+ %2 = shard.shard %1 to %s2 : tensor<8x16xf32>
+ %3 = shard.shard %2 to %s2 annotate_for_users : tensor<8x16xf32>
// CHECK: return %[[RES]] : tensor<4x16xf32>
- return %2 : tensor<8x16xf32>
+ return %3 : tensor<8x16xf32>
}
shard.grid @grid_1d_4(shape = 4)
@@ -301,7 +303,8 @@ func.func @test_reduce_1d(%arg0: tensor<6x6xi32>) -> (tensor<6xi32>) {
%sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
%sharded = shard.shard %arg0 to %sharding annotate_for_users : tensor<6x6xi32>
%4 = tensor.empty() : tensor<6xi32>
- %sharded_out = shard.shard %4 to %sharding : tensor<6xi32>
+ %sharded_4 = shard.shard %4 to %sharding : tensor<6xi32>
+ %sharded_out = shard.shard %sharded_4 to %sharding annotate_for_users : tensor<6xi32>
%sharded_in = shard.shard %sharded to %sharding annotate_for_users : tensor<6x6xi32>
// CHECK: %[[reduced:.*]] = linalg.reduce ins(%arg0 : tensor<3x6xi32>)
%reduced = linalg.reduce ins(%sharded_in : tensor<6x6xi32>) outs(%sharded_out : tensor<6xi32>) dimensions = [1]
More information about the Mlir-commits
mailing list