[Mlir-commits] [mlir] 05781f4 - [mlir][shard] Small fixes to partition pass (#185050)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 9 02:47:49 PDT 2026


Author: Frank Schlimbach
Date: 2026-03-09T10:47:45+01:00
New Revision: 05781f40170d8ab1add63a49fc00fbfd8f70a062

URL: https://github.com/llvm/llvm-project/commit/05781f40170d8ab1add63a49fc00fbfd8f70a062
DIFF: https://github.com/llvm/llvm-project/commit/05781f40170d8ab1add63a49fc00fbfd8f70a062.diff

LOG: [mlir][shard] Small fixes to partition pass (#185050)

- Empty functions (with no blocks) are now skipped by the partition pass;
  functions with more than one block continue to get error-flagged
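- Fixed ShardingInterfaceImpl of bufferization.materialize_in_destination
  (it now attaches ElementwiseShardingInterface instead of
  IndependentParallelIteratorDomainShardingInterface)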

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
    mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
    mlir/test/Dialect/Shard/sharding-propagation-failed.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
index 7d6d2a8378813..40a26cf6334a2 100644
--- a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -12,22 +12,20 @@
 #include "mlir/IR/DialectRegistry.h"
 
 using namespace mlir;
-
-/// Variadic helper function.
-template <typename... OpTypes>
-static void registerAll(MLIRContext *ctx) {
-  (OpTypes::template attachInterface<
-       shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(
-       *ctx),
-   ...);
-}
+using namespace mlir::bufferization;
+using namespace mlir::shard;
 
 void mlir::bufferization::shard_ext::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
-  registry.addExtension(+[](MLIRContext *ctx,
-                            bufferization::BufferizationDialect *dialect) {
-    registerAll<bufferization::AllocTensorOp, bufferization::DeallocTensorOp,
-                bufferization::MaterializeInDestinationOp>(ctx);
+  registry.addExtension(+[](MLIRContext *ctx, BufferizationDialect *dialect) {
+    AllocTensorOp::attachInterface<
+        IndependentParallelIteratorDomainShardingInterface<AllocTensorOp>>(
+        *ctx);
+    DeallocTensorOp::attachInterface<
+        IndependentParallelIteratorDomainShardingInterface<DeallocTensorOp>>(
+        *ctx);
+    MaterializeInDestinationOp::attachInterface<
+        ElementwiseShardingInterface<MaterializeInDestinationOp>>(*ctx);
   });
 }

diff  --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index cff02d4f03143..1e7deda5c6377 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -364,12 +364,19 @@ struct ShardingPropagation
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
     Region &region = funcOp.getFunctionBody();
-    OpBuilder builder(ctx);
+
+    if (region.empty())
+      return;
+
+    Block &block = region.front();
+    // Nothing to propagate if there is no sharding annotation in the block.
+    if (block.getOps<shard::ShardOp>().empty())
+      return;
+
     if (!region.hasOneBlock()) {
       funcOp.emitOpError() << "only one block is supported!";
       return signalPassFailure();
     }
-    Block &block = region.front();
 
     LLVM_DEBUG(
         DBGS() << "print all the ops' iterator types and indexing maps in the "
@@ -379,10 +386,7 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
-    // Nothing to propagate if there is no sharding annotation in the block.
-    if (block.getOps<shard::ShardOp>().empty())
-      return;
-
+    OpBuilder builder(ctx);
     auto traverse = [&](auto &&range, OpBuilder &builder,
                         const char *order) -> bool {
       for (Operation &op : range) {

diff  --git a/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
index b5eb98d859c36..3459c1c9f6edc 100644
--- a/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
+++ b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
@@ -1,4 +1,11 @@
 // RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s -verify-diagnostics
 
+shard.grid @grid(shape = 1) {sym_visibility = "private"}
 // expected-error @+1 {{'func.func' op only one block is supported!}}
-func.func private @no_block_function(i64)
+func.func @multi_block_function(%arg0 : tensor<6x6xi32>) -> tensor<6x6xi32> {
+    %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+    %sharded = shard.shard %arg0 to %sharding : tensor<6x6xi32>
+    cf.br ^bb1
+  ^bb1:
+    return %sharded : tensor<6x6xi32>
+}


        


More information about the Mlir-commits mailing list