[Mlir-commits] [mlir] [MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extension lib (PR #104913)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 20 03:35:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-func

Author: Frank Schlimbach (fschlimb)

<details>
<summary>Changes</summary>

Follow-up to #102598: as discussed, move the tensor sharding implementation into a separate tensor extension lib.

@sogartar @yaochengji, could you take a look at this PR?

---
Full diff: https://github.com/llvm/llvm-project/pull/104913.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h (+1-1) 
- (added) mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h (+30) 
- (added) mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h (+23) 
- (renamed) mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h () 
- (modified) mlir/include/mlir/InitAllDialects.h (-2) 
- (modified) mlir/include/mlir/InitAllExtensions.h (+2) 
- (modified) mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt (-1) 
- (modified) mlir/lib/Dialect/Tensor/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp (+16) 
- (added) mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt (+26) 
- (renamed) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+1-1) 
- (modified) mlir/tools/mlir-lsp-server/CMakeLists.txt (+1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
index 9b7abbca5d7622..30d3033209d213 100644
--- a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
new file mode 100644
index 00000000000000..db0afa858b1fa0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
@@ -0,0 +1,30 @@
+//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares a common entry point for registering all extensions to
+// the Tensor dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
+#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace tensor {
+/// Register all extensions of the Tensor dialect. This should generally only be
+/// used by tools, or other use cases that really do want *all* extensions of
+/// the dialect. All other cases should prefer to instead register the specific
+/// extensions they intend to take advantage of.
+void registerAllExtensions(DialectRegistry &registry);
+} // namespace tensor
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
new file mode 100644
index 00000000000000..cfac485b807f2b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
+//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_MESHSHARDINGEXTENSIONS_H
+#define MLIR_DIALECT_TENSOR_EXTENSIONS_MESHSHARDINGEXTENSIONS_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tensor {
+/// Register the Mesh `ShardingInterface` external models for Tensor ops.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_MESHSHARDINGEXTENSIONS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
similarity index 100%
rename from mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index ab81832cdbee55..73dccdb017ee14 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -58,7 +58,6 @@
 #include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
-  tensor::registerShardingInterfaceExternalModels(registry);
   tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 0adc5e52f2a0e5..dc5d4fbea04f49 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
 #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
@@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertComplexToLLVMInterface(registry);
   cf::registerConvertControlFlowToLLVMInterface(registry);
   func::registerAllExtensions(registry);
+  tensor::registerAllExtensions(registry);
   registerConvertFuncToLLVMInterface(registry);
   index::registerConvertIndexToLLVMInterface(registry);
   registerConvertMathToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
index 266fa6fa54557c..afe76b539846a7 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_library(MLIRShardingInterface
   ShardingInterface.cpp
-  TensorShardingInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
diff --git a/mlir/lib/Dialect/Tensor/CMakeLists.txt b/mlir/lib/Dialect/Tensor/CMakeLists.txt
index 329a6c3e80254f..a834aae8fbf81e 100644
--- a/mlir/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Extensions)
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
new file mode 100644
index 00000000000000..93e1a2021857d3
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
@@ -0,0 +1,16 @@
+//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
+
+using namespace mlir;
+
+void mlir::tensor::registerAllExtensions(DialectRegistry &registry) {
+  registerShardingInterfaceExternalModels(registry);
+}
diff --git a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
new file mode 100644
index 00000000000000..dba59333666f6b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_OPTIONAL_SOURCES
+  AllExtensions.cpp
+  MeshShardingExtensions.cpp
+  )
+
+add_mlir_extension_library(MLIRTensorMeshShardingExtensions
+  MeshShardingExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRTensorDialect
+  MLIRIR
+  MLIRShardingInterface
+  )
+
+add_mlir_extension_library(MLIRTensorAllExtensions
+  AllExtensions.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
+
+  LINK_LIBS PUBLIC
+  MLIRTensorMeshShardingExtensions
+  )
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
similarity index 98%
rename from mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
rename to mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 9422dd4a529fd4..f3e72abe7516ee 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "llvm/Support/Debug.h"
diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt
index 0134b54eef1b07..8ff9cc2f07e8eb 100644
--- a/mlir/tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt
@@ -47,6 +47,7 @@ set(LIBS
   MLIRLspServerLib
   MLIRParser
   MLIRPass
+  MLIRTensorAllExtensions
   MLIRTransforms
   MLIRTransformUtils
   MLIRSupport

``````````

</details>


https://github.com/llvm/llvm-project/pull/104913


More information about the Mlir-commits mailing list