[Mlir-commits] [mlir] [MLIR][mesh] moving shardinginterfaceimpl for tensor to tensor extension lib (PR #104913)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 20 03:35:31 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-func
Author: Frank Schlimbach (fschlimb)
<details>
<summary>Changes</summary>
Follow-up to #102598: as discussed, move the tensor sharding implementation into a separate tensor extension library.
@sogartar @yaochengji, could you take a look at this PR?
---
Full diff: https://github.com/llvm/llvm-project/pull/104913.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h (+1-1)
- (added) mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h (+30)
- (added) mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h (+23)
- (renamed) mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h ()
- (modified) mlir/include/mlir/InitAllDialects.h (-2)
- (modified) mlir/include/mlir/InitAllExtensions.h (+2)
- (modified) mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt (-1)
- (modified) mlir/lib/Dialect/Tensor/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp (+16)
- (added) mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt (+26)
- (renamed) mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp (+1-1)
- (modified) mlir/tools/mlir-lsp-server/CMakeLists.txt (+1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
index 9b7abbca5d7622..30d3033209d213 100644
--- a/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
+++ b/mlir/include/mlir/Dialect/Func/Extensions/MeshShardingExtensions.h
@@ -1,4 +1,4 @@
-//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//===- MeshShardingExtensions.h - -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
new file mode 100644
index 00000000000000..db0afa858b1fa0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Extensions/AllExtensions.h
@@ -0,0 +1,30 @@
+//===- AllExtensions.h - All Tensor Extensions ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all extensions to the
+// Tensor dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
+#define MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace tensor {
+/// Register all extensions of the Tensor dialect. This should generally only be
+/// used by tools, or other use cases that really do want *all* extensions of
+/// the dialect. All other cases should prefer to instead register the specific
+/// extensions they intend to take advantage of.
+void registerAllExtensions(DialectRegistry &registry);
+} // namespace tensor
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_ALLEXTENSIONS_H
diff --git a/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h b/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
new file mode 100644
index 00000000000000..cfac485b807f2b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h
@@ -0,0 +1,23 @@
+//===- MeshShardingExtensions.h - -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
+#define MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tensor {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_EXTENSIONS_SHARDINGEXTENSIONS_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
similarity index 100%
rename from mlir/include/mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h
rename to mlir/include/mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index ab81832cdbee55..73dccdb017ee14 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -58,7 +58,6 @@
#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
-#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -182,7 +181,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerBufferizableOpInterfaceExternalModels(registry);
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
- tensor::registerShardingInterfaceExternalModels(registry);
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 0adc5e52f2a0e5..dc5d4fbea04f49 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -34,6 +34,7 @@
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
@@ -60,6 +61,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertComplexToLLVMInterface(registry);
cf::registerConvertControlFlowToLLVMInterface(registry);
func::registerAllExtensions(registry);
+ tensor::registerAllExtensions(registry);
registerConvertFuncToLLVMInterface(registry);
index::registerConvertIndexToLLVMInterface(registry);
registerConvertMathToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
index 266fa6fa54557c..afe76b539846a7 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_library(MLIRShardingInterface
ShardingInterface.cpp
- TensorShardingInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
diff --git a/mlir/lib/Dialect/Tensor/CMakeLists.txt b/mlir/lib/Dialect/Tensor/CMakeLists.txt
index 329a6c3e80254f..a834aae8fbf81e 100644
--- a/mlir/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Extensions)
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
new file mode 100644
index 00000000000000..93e1a2021857d3
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp
@@ -0,0 +1,16 @@
+//===- AllExtensions.cpp - All Tensor Dialect Extensions ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h"
+
+using namespace mlir;
+
+void mlir::tensor::registerAllExtensions(DialectRegistry &registry) {
+ registerShardingInterfaceExternalModels(registry);
+}
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
new file mode 100644
index 00000000000000..dba59333666f6b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt
@@ -0,0 +1,26 @@
+set(LLVM_OPTIONAL_SOURCES
+ AllExtensions.cpp
+ MeshShardingExtensions.cpp
+ )
+
+add_mlir_extension_library(MLIRTensorMeshShardingExtensions
+ MeshShardingExtensions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
+
+ LINK_LIBS PUBLIC
+ MLIRTensorDialect
+ MLIRIR
+ MLIRShardingInterface
+ )
+
+add_mlir_extension_library(MLIRTensorAllExtensions
+ AllExtensions.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions
+
+ LINK_LIBS PUBLIC
+ MLIRTensorMeshShardingExtensions
+ )
\ No newline at end of file
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
similarity index 98%
rename from mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
rename to mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
index 9422dd4a529fd4..f3e72abe7516ee 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/TensorShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Mesh/IR/TensorShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"
diff --git a/mlir/tools/mlir-lsp-server/CMakeLists.txt b/mlir/tools/mlir-lsp-server/CMakeLists.txt
index 0134b54eef1b07..8ff9cc2f07e8eb 100644
--- a/mlir/tools/mlir-lsp-server/CMakeLists.txt
+++ b/mlir/tools/mlir-lsp-server/CMakeLists.txt
@@ -47,6 +47,7 @@ set(LIBS
MLIRLspServerLib
MLIRParser
MLIRPass
+ MLIRTensorAllExtensions
MLIRTransforms
MLIRTransformUtils
MLIRSupport
``````````
</details>
https://github.com/llvm/llvm-project/pull/104913
More information about the Mlir-commits
mailing list