[Mlir-commits] [mlir] [mlir][shard] Small fixes to partition pass (PR #185050)
Frank Schlimbach
llvmlistbot at llvm.org
Fri Mar 6 09:01:29 PST 2026
https://github.com/fschlimb created https://github.com/llvm/llvm-project/pull/185050
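- Functions without a body (declarations with no blocks) are now skipped by the sharding-propagation pass instead of being flagged as an error; functions whose entry block contains no shard.shard op are likewise skipped before the single-block check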
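- Fixed the sharding interface of bufferization.materialize_in_destination: it is now registered with ElementwiseShardingInterface instead of IndependentParallelIteratorDomainShardingInterface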
Enables https://github.com/llvm/lighthouse/pull/65
>From 6b63108fb27f0ec2a0863cd1039c8beb54ba1842 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 6 Mar 2026 00:49:21 -0800
Subject: [PATCH 1/2] skip empty func.func in partition
---
.../Shard/Transforms/ShardingPropagation.cpp | 16 ++++++++++------
.../Shard/sharding-propagation-failed.mlir | 9 ++++++++-
2 files changed, 18 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index cff02d4f03143..1e7deda5c6377 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -364,12 +364,19 @@ struct ShardingPropagation
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
Region &region = funcOp.getFunctionBody();
- OpBuilder builder(ctx);
+
+ if (region.empty())
+ return;
+
+ Block &block = region.front();
+ // Nothing to propagate if there is no sharding annotation in the block.
+ if (block.getOps<shard::ShardOp>().empty())
+ return;
+
if (!region.hasOneBlock()) {
funcOp.emitOpError() << "only one block is supported!";
return signalPassFailure();
}
- Block &block = region.front();
LLVM_DEBUG(
DBGS() << "print all the ops' iterator types and indexing maps in the "
@@ -379,10 +386,7 @@ struct ShardingPropagation
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
});
- // Nothing to propagate if there is no sharding annotation in the block.
- if (block.getOps<shard::ShardOp>().empty())
- return;
-
+ OpBuilder builder(ctx);
auto traverse = [&](auto &&range, OpBuilder &builder,
const char *order) -> bool {
for (Operation &op : range) {
diff --git a/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
index b5eb98d859c36..3459c1c9f6edc 100644
--- a/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
+++ b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
@@ -1,4 +1,11 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s -verify-diagnostics
+shard.grid @grid(shape = 1) {sym_visibility = "private"}
// expected-error @+1 {{'func.func' op only one block is supported!}}
-func.func private @no_block_function(i64)
+func.func @multi_block_function(%arg0 : tensor<6x6xi32>) -> tensor<6x6xi32> {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %sharding : tensor<6x6xi32>
+ cf.br ^bb1
+ ^bb1:
+ return %sharded : tensor<6x6xi32>
+}
>From f2c27fbb06fa7a9d7e2f81319edbb09145e40ffb Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 6 Mar 2026 08:54:28 -0800
Subject: [PATCH 2/2] fixing ShardingInterface of bufferization ops
---
.../Extensions/ShardingExtensions.cpp | 24 +++++++++----------
1 file changed, 11 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
index 7d6d2a8378813..40a26cf6334a2 100644
--- a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -12,22 +12,20 @@
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
-
-/// Variadic helper function.
-template <typename... OpTypes>
-static void registerAll(MLIRContext *ctx) {
- (OpTypes::template attachInterface<
- shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(
- *ctx),
- ...);
-}
+using namespace mlir::bufferization;
+using namespace mlir::shard;
void mlir::bufferization::shard_ext::registerShardingInterfaceExternalModels(
DialectRegistry &registry) {
- registry.addExtension(+[](MLIRContext *ctx,
- bufferization::BufferizationDialect *dialect) {
- registerAll<bufferization::AllocTensorOp, bufferization::DeallocTensorOp,
- bufferization::MaterializeInDestinationOp>(ctx);
+ registry.addExtension(+[](MLIRContext *ctx, BufferizationDialect *dialect) {
+ AllocTensorOp::attachInterface<
+ IndependentParallelIteratorDomainShardingInterface<AllocTensorOp>>(
+ *ctx);
+ DeallocTensorOp::attachInterface<
+ IndependentParallelIteratorDomainShardingInterface<DeallocTensorOp>>(
+ *ctx);
+ MaterializeInDestinationOp::attachInterface<
+ ElementwiseShardingInterface<MaterializeInDestinationOp>>(*ctx);
});
}
More information about the Mlir-commits
mailing list