[Mlir-commits] [mlir] ec57636 - [mlir][shard, bufferization] Adding sharding extensions for bufferization ops (#177378)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 27 02:02:22 PST 2026
Author: Frank Schlimbach
Date: 2026-01-27T11:02:18+01:00
New Revision: ec57636ae447247683716c00437552645a52ba68
URL: https://github.com/llvm/llvm-project/commit/ec57636ae447247683716c00437552645a52ba68
DIFF: https://github.com/llvm/llvm-project/commit/ec57636ae447247683716c00437552645a52ba68.diff
LOG: [mlir][shard, bufferization] Adding sharding extensions for bufferization ops (#177378)
Adding trivial sharding support for `bufferization.alloc_tensor`,
`bufferization.dealloc_tensor` and
`bufferization.materialize_in_destination`.
include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h -> mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
---------
Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
Added:
mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp
mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt
mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
mlir/test/Dialect/Bufferization/shard-partition.mlir
Modified:
mlir/lib/Dialect/Bufferization/CMakeLists.txt
mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
mlir/lib/RegisterAllExtensions.cpp
Removed:
mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h b/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
new file mode 100644
index 0000000000000..e9f87c35a20f9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
@@ -0,0 +1,30 @@
+//===- AllExtensions.h - All Bufferization Extensions -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all extensions to the
+// bufferization dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_EXTENSIONS_ALLEXTENSIONS_H
+#define MLIR_DIALECT_BUFFERIZATION_EXTENSIONS_ALLEXTENSIONS_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace bufferization {
+/// Register all extensions of the bufferization dialect. This should generally
+/// only be used by tools, or other use cases that really do want *all*
+/// extensions of the dialect. All other cases should prefer to instead register
+/// the specific extensions they intend to take advantage of.
+void registerAllExtensions(DialectRegistry &registry);
+} // namespace bufferization
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_EXTENSIONS_ALLEXTENSIONS_H
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
similarity index 56%
rename from mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
index 3e23419eeec07..0fefd54c0255f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//===- ShardingExtensions.h -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,17 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
-#define MLIR_DIALECT_TENSOR_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#ifndef MLIR_DIALECT_BUFFERIZATION_SHARDINGEXTENSIONS_H
+#define MLIR_DIALECT_BUFFERIZATION_SHARDINGEXTENSIONS_H
namespace mlir {
-
class DialectRegistry;
-namespace tensor {
-
+namespace bufferization {
+namespace shard_ext {
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace tensor
+} // namespace shard_ext
+} // namespace bufferization
} // namespace mlir
-#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#endif // MLIR_DIALECT_BUFFERIZATION_SHARDINGEXTENSIONS_H
diff --git a/mlir/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/CMakeLists.txt
index 215ec562c9818..9e4f1de15620c 100644
--- a/mlir/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Pipelines)
add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp
new file mode 100644
index 0000000000000..1cff6352434d7
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp
@@ -0,0 +1,16 @@
+//===- AllExtensions.cpp - All Bufferization Dialect Extensions -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h"
+
+using namespace mlir;
+
+void mlir::bufferization::registerAllExtensions(DialectRegistry &registry) {
+ shard_ext::registerShardingInterfaceExternalModels(registry);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt
new file mode 100644
index 0000000000000..73a393e3a2b76
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_OPTIONAL_SOURCES
+ AllExtensions.cpp
+ ShardingExtensions.cpp
+ )
+
+add_mlir_extension_library(MLIRBufferizationShardingExtensions
+ ShardingExtensions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/Extensions
+
+ LINK_LIBS PUBLIC
+ MLIRBufferizationDialect
+ MLIRIR
+ MLIRShardingInterface
+ )
+
+add_mlir_extension_library(MLIRBufferizationAllExtensions
+ AllExtensions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/Extensions
+
+ LINK_LIBS PUBLIC
+ MLIRBufferizationShardingExtensions
+ )
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
new file mode 100644
index 0000000000000..7d6d2a8378813
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -0,0 +1,33 @@
+//===- ShardingExtensions.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+/// Variadic helper function.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+ (OpTypes::template attachInterface<
+ shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(
+ *ctx),
+ ...);
+}
+
+void mlir::bufferization::shard_ext::registerShardingInterfaceExternalModels(
+ DialectRegistry &registry) {
+
+ registry.addExtension(+[](MLIRContext *ctx,
+ bufferization::BufferizationDialect *dialect) {
+ registerAll<bufferization::AllocTensorOp, bufferization::DeallocTensorOp,
+ bufferization::MaterializeInDestinationOp>(ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
index ca7287cec55ce..afdb938e15776 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//===- ShardingExtensions.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 4312100a0c0b0..d5693db23225a 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -37,6 +37,7 @@
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+#include "mlir/Dialect/Bufferization/Extensions/AllExtensions.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
@@ -73,6 +74,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
// Register all conversions to LLVM extensions.
registerConvertArithToEmitCInterface(registry);
arith::registerConvertArithToLLVMInterface(registry);
+ bufferization::registerAllExtensions(registry);
registerConvertComplexToLLVMInterface(registry);
cf::registerConvertControlFlowToLLVMInterface(registry);
func::registerAllExtensions(registry);
diff --git a/mlir/test/Dialect/Bufferization/shard-partition.mlir b/mlir/test/Dialect/Bufferization/shard-partition.mlir
new file mode 100644
index 0000000000000..720753227d99b
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/shard-partition.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt \
+// RUN: --pass-pipeline="builtin.module(func.func(shard-partition))" \
+// RUN: %s | FileCheck %s
+
+shard.grid @grid(shape = 4)
+
+// CHECK-LABEL: func @test_alloc_tensor_op
+// CHECK-SAME: tensor<?x2xf32>
+func.func @test_alloc_tensor_op(%t: tensor<?x8xf32>, %sz: index)
+{
+ %sharding = shard.sharding @grid split_axes = [[], [0]] : !shard.sharding
+ %sharded = shard.shard %t to %sharding : tensor<?x8xf32>
+ // CHECK: bufferization.alloc_tensor(%{{.*}}) : tensor<?x2xf32>
+ %0 = bufferization.alloc_tensor(%sz) : tensor<?x8xf32>
+ %sharded0 = shard.shard %0 to %sharding : tensor<?x8xf32>
+ %sharded1 = shard.shard %sharded to %sharding annotate_for_users : tensor<?x8xf32>
+ // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<?x2xf32>
+ %4 = bufferization.alloc_tensor() copy(%sharded1) {escape = true} : tensor<?x8xf32>
+ %sharded4 = shard.shard %4 to %sharding : tensor<?x8xf32>
+ return
+}
+
+// CHECK-LABEL: func @test_dealloc_tensor_op
+// CHECK-SAME: tensor<1xi32>
+func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded = shard.shard %arg0 to %sharding : tensor<4xi32>
+ %sharded1 = shard.shard %sharded to %sharding annotate_for_users : tensor<4xi32>
+ // CHECK: bufferization.dealloc_tensor {{.*}} : tensor<1xi32>
+ bufferization.dealloc_tensor %sharded1 : tensor<4xi32>
+ return
+}
+
+// CHECK-LABEL: func @test_materialize_in_destination_op
+// CHECK-SAME: tensor<2xf32>) -> tensor<2xf32>
+func.func @test_materialize_in_destination_op(
+ %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+ %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+ %sharded0 = shard.shard %arg0 to %sharding : tensor<?xf32>
+ %sharded1 = shard.shard %arg1 to %sharding : tensor<?xf32>
+ %sharded2 = shard.shard %arg2 to %sharding : tensor<8xf32>
+ %sharded0_in = shard.shard %sharded0 to %sharding annotate_for_users : tensor<?xf32>
+ %sharded1_in = shard.shard %sharded1 to %sharding annotate_for_users : tensor<?xf32>
+ // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ %0 = bufferization.materialize_in_destination %sharded0_in in %sharded1_in : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ %sharded_res0 = shard.shard %0 to %sharding : tensor<?xf32>
+ // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<2xf32>) -> tensor<2xf32>
+ %sharded0_in2 = shard.shard %sharded0 to %sharding annotate_for_users : tensor<?xf32>
+ %sharded2_in1 = shard.shard %sharded2 to %sharding annotate_for_users : tensor<8xf32>
+ %1 = bufferization.materialize_in_destination %sharded0_in2 in %sharded2_in1 : (tensor<?xf32>, tensor<8xf32>) -> tensor<8xf32>
+ %sharded_res1 = shard.shard %1 to %sharding : tensor<8xf32>
+ %sharded_res1_in = shard.shard %sharded_res1 to %sharding annotate_for_users : tensor<8xf32>
+ // CHECK tensor<2xf32>
+ return %sharded_res1_in : tensor<8xf32>
+}
More information about the Mlir-commits
mailing list