[Mlir-commits] [llvm] [mlir] [mlir][shard, bufferization] Adding sharding extensions for bufferization ops (PR #177378)

Frank Schlimbach llvmlistbot at llvm.org
Fri Jan 23 07:13:38 PST 2026


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/177378

>From 71b959e7da4f96dae199c6bc91dd1d76741fb5c4 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 22 Jan 2026 07:16:58 -0800
Subject: [PATCH 1/7] adding sharding extensions for bufferization ops

---
 .../Bufferization/Extensions/AllExtensions.h  | 30 ++++++++++
 .../Extensions/ShardingExtensions.h           | 22 ++++++++
 mlir/lib/Dialect/Bufferization/CMakeLists.txt |  1 +
 .../Extensions/AllExtensions.cpp              | 16 ++++++
 .../Bufferization/Extensions/CMakeLists.txt   | 26 +++++++++
 .../Extensions/ShardingExtensions.cpp         | 32 +++++++++++
 mlir/lib/RegisterAllExtensions.cpp            |  2 +
 .../Bufferization/shard-partition.mlir        | 55 +++++++++++++++++++
 8 files changed, 184 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
 create mode 100644 mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp
 create mode 100644 mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
 create mode 100644 mlir/test/Dialect/Bufferization/shard-partition.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h b/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
new file mode 100644
index 0000000000000..2552a5ecda52d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
@@ -0,0 +1,30 @@
+//===- AllExtensions.h - All bufferization Extensions ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all extensions to the
+// bufferization dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_EXTENSIONS_ALLEXTENSIONS_H
+#define MLIR_DIALECT_BUFFERIZATION_EXTENSIONS_ALLEXTENSIONS_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace bufferization {
+/// Register all extensions of the bufferization dialect. This should generally only be
+/// used by tools, or other use cases that really do want *all* extensions of
+/// the dialect. All other cases should prefer to instead register the specific
+/// extensions they intend to take advantage of.
+void registerAllExtensions(DialectRegistry &registry);
+} // namespace bufferization
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_EXTENSIONS_ALLEXTENSIONS_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h b/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
new file mode 100644
index 0000000000000..575e1335a9678
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
@@ -0,0 +1,22 @@
+//===- ShardingInterfaceImpl.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_SHARDSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_BUFFERIZATION_SHARDSHARDINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace bufferization {
+namespace shard_ext {
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace shard_ext
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_SHARDSHARDINGINTERFACEIMPL_H
diff --git a/mlir/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/CMakeLists.txt
index 215ec562c9818..9e4f1de15620c 100644
--- a/mlir/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Extensions)
 add_subdirectory(IR)
 add_subdirectory(Pipelines)
 add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp
new file mode 100644
index 0000000000000..1cff6352434d7
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Extensions/AllExtensions.cpp
@@ -0,0 +1,16 @@
+//===- AllExtensions.cpp - All Bufferization Dialect Extensions -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h"
+
+using namespace mlir;
+
+void mlir::bufferization::registerAllExtensions(DialectRegistry &registry) {
+  shard_ext::registerShardingInterfaceExternalModels(registry);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt
new file mode 100644
index 0000000000000..73a393e3a2b76
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Extensions/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_OPTIONAL_SOURCES
+  AllExtensions.cpp
+  ShardingExtensions.cpp
+  )
+
+add_mlir_extension_library(MLIRBufferizationShardingExtensions
+  ShardingExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRBufferizationDialect
+  MLIRIR
+  MLIRShardingInterface
+  )
+
+add_mlir_extension_library(MLIRBufferizationAllExtensions
+  AllExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRBufferizationShardingExtensions
+  )
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
new file mode 100644
index 0000000000000..e8a4305698402
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -0,0 +1,32 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+using namespace mlir;
+
+/// Variadic helper function.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+  (OpTypes::template attachInterface<shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(*ctx), ...);
+}
+
+void mlir::bufferization::shard_ext::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+  registry.addExtension(+[](MLIRContext *ctx, bufferization::BufferizationDialect *dialect) {
+        registerAll<
+          bufferization::AllocTensorOp,
+          bufferization::DeallocTensorOp,
+          bufferization::MaterializeInDestinationOp
+        >(ctx);
+  });
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 4312100a0c0b0..f8a22403c1f07 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -38,6 +38,7 @@
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/Bufferization/Extensions/AllExtensions.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
@@ -73,6 +74,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   // Register all conversions to LLVM extensions.
   registerConvertArithToEmitCInterface(registry);
   arith::registerConvertArithToLLVMInterface(registry);
+  bufferization::registerAllExtensions(registry);
   registerConvertComplexToLLVMInterface(registry);
   cf::registerConvertControlFlowToLLVMInterface(registry);
   func::registerAllExtensions(registry);
diff --git a/mlir/test/Dialect/Bufferization/shard-partition.mlir b/mlir/test/Dialect/Bufferization/shard-partition.mlir
new file mode 100644
index 0000000000000..720753227d99b
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/shard-partition.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(func.func(shard-partition))" \
+// RUN:   %s | FileCheck %s
+
+shard.grid @grid(shape = 4)
+
+// CHECK-LABEL: func @test_alloc_tensor_op
+// CHECK-SAME: tensor<?x2xf32>
+func.func @test_alloc_tensor_op(%t: tensor<?x8xf32>, %sz: index)
+{
+  %sharding = shard.sharding @grid split_axes = [[], [0]] : !shard.sharding
+  %sharded = shard.shard %t to %sharding : tensor<?x8xf32>
+  // CHECK: bufferization.alloc_tensor(%{{.*}}) : tensor<?x2xf32>
+  %0 = bufferization.alloc_tensor(%sz) : tensor<?x8xf32>
+  %sharded0 = shard.shard %0 to %sharding : tensor<?x8xf32>
+  %sharded1 = shard.shard %sharded to %sharding annotate_for_users : tensor<?x8xf32>
+  // CHECK: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<?x2xf32>
+  %4 = bufferization.alloc_tensor() copy(%sharded1) {escape = true} : tensor<?x8xf32>
+  %sharded4 = shard.shard %4 to %sharding : tensor<?x8xf32>
+  return
+}
+
+// CHECK-LABEL: func @test_dealloc_tensor_op
+// CHECK-SAME: tensor<1xi32>
+func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %sharded = shard.shard %arg0 to %sharding : tensor<4xi32>
+  %sharded1 = shard.shard %sharded to %sharding annotate_for_users : tensor<4xi32>
+  // CHECK: bufferization.dealloc_tensor {{.*}} : tensor<1xi32>
+  bufferization.dealloc_tensor %sharded1 : tensor<4xi32>
+  return
+}
+
+// CHECK-LABEL: func @test_materialize_in_destination_op
+// CHECK-SAME: tensor<2xf32>) -> tensor<2xf32>
+func.func @test_materialize_in_destination_op(
+    %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+  %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+  %sharded0 = shard.shard %arg0 to %sharding : tensor<?xf32>
+  %sharded1 = shard.shard %arg1 to %sharding : tensor<?xf32>
+  %sharded2 = shard.shard %arg2 to %sharding : tensor<8xf32>
+  %sharded0_in = shard.shard %sharded0 to %sharding annotate_for_users : tensor<?xf32>
+  %sharded1_in = shard.shard %sharded1 to %sharding annotate_for_users : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %0 = bufferization.materialize_in_destination %sharded0_in in %sharded1_in : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  %sharded_res0 = shard.shard %0 to %sharding : tensor<?xf32>
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<2xf32>) -> tensor<2xf32>
+  %sharded0_in2 = shard.shard %sharded0 to %sharding annotate_for_users : tensor<?xf32>
+  %sharded2_in1 = shard.shard %sharded2 to %sharding annotate_for_users : tensor<8xf32>
+  %1 = bufferization.materialize_in_destination %sharded0_in2 in %sharded2_in1 : (tensor<?xf32>, tensor<8xf32>) -> tensor<8xf32>
+  %sharded_res1 = shard.shard %1 to %sharding : tensor<8xf32>
+  %sharded_res1_in = shard.shard %sharded_res1 to %sharding annotate_for_users : tensor<8xf32>
+  // CHECK: tensor<2xf32>
+  return %sharded_res1_in : tensor<8xf32>
+}

>From a7568c8584be6a31397a6179f9e129111739979a Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Thu, 22 Jan 2026 07:28:50 -0800
Subject: [PATCH 2/7] clang-format

---
 .../Bufferization/Extensions/AllExtensions.h    | 10 +++++-----
 .../Extensions/ShardingExtensions.cpp           | 17 +++++++++--------
 mlir/lib/RegisterAllExtensions.cpp              |  2 +-
 3 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h b/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
index 2552a5ecda52d..371ff85369081 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Extensions/AllExtensions.h
@@ -1,4 +1,4 @@
-//===- AllExtensions.h - All bufferization Extensions ------------------*- C++ -*-===//
+//===- AllExtensions.h - All bufferization Extensions -----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -18,10 +18,10 @@ namespace mlir {
 class DialectRegistry;
 
 namespace bufferization {
-/// Register all extensions of the bufferization dialect. This should generally only be
-/// used by tools, or other use cases that really do want *all* extensions of
-/// the dialect. All other cases should prefer to instead register the specific
-/// extensions they intend to take advantage of.
+/// Register all extensions of the bufferization dialect. This should generally
+/// only be used by tools, or other use cases that really do want *all*
+/// extensions of the dialect. All other cases should prefer to instead register
+/// the specific extensions they intend to take advantage of.
 void registerAllExtensions(DialectRegistry &registry);
 } // namespace bufferization
 
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
index e8a4305698402..41812e7a53a13 100644
--- a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -7,26 +7,27 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/IR/DialectRegistry.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 
 using namespace mlir;
 
 /// Variadic helper function.
 template <typename... OpTypes>
 static void registerAll(MLIRContext *ctx) {
-  (OpTypes::template attachInterface<shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(*ctx), ...);
+  (OpTypes::template attachInterface<
+       shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(
+       *ctx),
+   ...);
 }
 
 void mlir::bufferization::shard_ext::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
-  registry.addExtension(+[](MLIRContext *ctx, bufferization::BufferizationDialect *dialect) {
-        registerAll<
-          bufferization::AllocTensorOp,
-          bufferization::DeallocTensorOp,
-          bufferization::MaterializeInDestinationOp
-        >(ctx);
+  registry.addExtension(+[](MLIRContext *ctx,
+                            bufferization::BufferizationDialect *dialect) {
+    registerAll<bufferization::AllocTensorOp, bufferization::DeallocTensorOp,
+                bufferization::MaterializeInDestinationOp>(ctx);
   });
 }
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index f8a22403c1f07..d5693db23225a 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -37,8 +37,8 @@
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
-#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/Bufferization/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
 #include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"

>From ddf8bf69ac1117ca8033af5bdb993d87ce2e60ed Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 23 Jan 2026 06:07:02 -0800
Subject: [PATCH 3/7] formatting

---
 .../Dialect/Bufferization/Extensions/ShardingExtensions.h | 8 ++++----
 .../Bufferization/Extensions/ShardingExtensions.cpp       | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h b/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
index 575e1335a9678..0fefd54c0255f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Extensions/ShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.h ----------------------------------------===//
+//===- ShardingExtensions.h -----------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_BUFFERIZATION_SHARDSHARDINGINTERFACEIMPL_H
-#define MLIR_DIALECT_BUFFERIZATION_SHARDSHARDINGINTERFACEIMPL_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_SHARDINGEXTENSIONS_H
+#define MLIR_DIALECT_BUFFERIZATION_SHARDINGEXTENSIONS_H
 
 namespace mlir {
 class DialectRegistry;
@@ -19,4 +19,4 @@ void registerShardingInterfaceExternalModels(DialectRegistry &registry);
 } // namespace bufferization
 } // namespace mlir
 
-#endif // MLIR_DIALECT_BUFFERIZATION_SHARDSHARDINGINTERFACEIMPL_H
+#endif // MLIR_DIALECT_BUFFERIZATION_SHARDINGEXTENSIONS_H
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
index 41812e7a53a13..7d6d2a8378813 100644
--- a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//===- ShardingExtensions.cpp ---------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From 183cae72bddda4b1c1b650bf1fb507901cc05fdb Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 23 Jan 2026 06:16:38 -0800
Subject: [PATCH 4/7] removing spurious ShardingInterfaceImpl.h from tensor

---
 .../Dialect/Tensor/IR/ShardingInterfaceImpl.h | 23 -------------------
 .../Tensor/Extensions/ShardingExtensions.cpp  |  4 ++--
 .../llvm-project-overlay/mlir/BUILD.bazel     |  2 +-
 3 files changed, 3 insertions(+), 26 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
deleted file mode 100644
index 3e23419eeec07..0000000000000
--- a/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
+++ /dev/null
@@ -1,23 +0,0 @@
-//===- ShardingInterfaceImpl.h - ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
-#define MLIR_DIALECT_TENSOR_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
-
-namespace mlir {
-
-class DialectRegistry;
-
-namespace tensor {
-
-void registerShardingInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace tensor
-} // namespace mlir
-
-#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
index ca7287cec55ce..07c79a7073cc2 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//===- ShardingExtensions.cpp ---------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -8,7 +8,7 @@
 
 #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectRegistry.h"
 
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 73adfa40d831f..b9de832e87c2e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7647,7 +7647,7 @@ cc_library(
         "lib/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.cpp",
     ],
     hdrs = [
-        "include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h",
+        "include/mlir/Dialect/Tensor/Extensions/ShardingExtensions.h",
         "include/mlir/Dialect/Tensor/IR/Tensor.h",
         "include/mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h",
     ],

>From 0ac483d61d5d6b45c470d4fdc97624757cdd9041 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 23 Jan 2026 06:19:51 -0800
Subject: [PATCH 5/7] clang-format

---
 mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
index 07c79a7073cc2..afdb938e15776 100644
--- a/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
 #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectRegistry.h"
 

>From 05b9b5b2edb4b4672d3238752646b53092a14f48 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 23 Jan 2026 07:08:27 -0800
Subject: [PATCH 6/7] bazel

---
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index b9de832e87c2e..afd37bd93dc4d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -13997,9 +13997,15 @@ cc_library(
     srcs = glob(
         [
             "lib/Dialect/Bufferization/Transforms/*.cpp",
+            "lib/Dialect/Bufferization/Extensions/*.cpp",
         ],
     ),
-    hdrs = glob(["include/mlir/Dialect/Bufferization/Transforms/*.h"]),
+    hdrs = glob(
+        [
+            "include/mlir/Dialect/Bufferization/Transforms/*.h",
+            "include/mlir/Dialect/Bufferization/Extensions/*s.h",
+        ]
+    ),
     includes = ["include"],
     deps = [
         ":AllocationOpInterface",

>From df51445f526a1af67f430f86e4fe1e20354d721b Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 23 Jan 2026 07:13:20 -0800
Subject: [PATCH 7/7] bazel

---
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index afd37bd93dc4d..6bd11625a14f1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -14003,7 +14003,7 @@ cc_library(
     hdrs = glob(
         [
             "include/mlir/Dialect/Bufferization/Transforms/*.h",
-            "include/mlir/Dialect/Bufferization/Extensions/*s.h",
+            "include/mlir/Dialect/Bufferization/Extensions/*.h",
         ]
     ),
     includes = ["include"],
@@ -14025,6 +14025,7 @@ cc_library(
         ":MemRefUtils",
         ":Pass",
         ":SCFDialect",
+        "ShardingInterface",
         ":SideEffectInterfaces",
         ":SubsetOpInterface",
         ":TensorDialect",



More information about the Mlir-commits mailing list