[Mlir-commits] [mlir] 2a6c8b2 - [mlir][PassIncGen] Refactor how pass registration is generated

River Riddle llvmlistbot at llvm.org
Fri Jul 31 13:26:38 PDT 2020


Author: River Riddle
Date: 2020-07-31T13:20:37-07:00
New Revision: 2a6c8b2e9581ebca4b05d1e64458f2dccf3db61f

URL: https://github.com/llvm/llvm-project/commit/2a6c8b2e9581ebca4b05d1e64458f2dccf3db61f
DIFF: https://github.com/llvm/llvm-project/commit/2a6c8b2e9581ebca4b05d1e64458f2dccf3db61f.diff

LOG: [mlir][PassIncGen] Refactor how pass registration is generated

The current output is a bit clunky and requires including files+macros everywhere, or manually wrapping the file inclusion in a registration function. This revision refactors the pass backend to automatically generate `registerFooPass`/`registerFooPasses` functions that wrap the pass registration. `gen-pass-decls` now takes a `-name` input that specifies a tag name for the group of passes that are being generated. For each pass, the generator now produces a `registerFooPass` where `Foo` is the name of the definition specified in tablegen. It also generates a `registerGroupPasses`, where `Group` is the tag provided via the `-name` input parameter, that registers all of the passes present.

Differential Revision: https://reviews.llvm.org/D84983

Added: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h

Modified: 
    flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
    flang/include/flang/Optimizer/CodeGen/CodeGen.h
    flang/include/flang/Optimizer/Transforms/CMakeLists.txt
    flang/include/flang/Optimizer/Transforms/Passes.h
    mlir/docs/PassManagement.md
    mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
    mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
    mlir/include/mlir/Conversion/CMakeLists.txt
    mlir/include/mlir/Dialect/Affine/CMakeLists.txt
    mlir/include/mlir/Dialect/Affine/Passes.h
    mlir/include/mlir/Dialect/GPU/CMakeLists.txt
    mlir/include/mlir/Dialect/GPU/Passes.h
    mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Quant/CMakeLists.txt
    mlir/include/mlir/Dialect/Quant/Passes.h
    mlir/include/mlir/Dialect/SCF/CMakeLists.txt
    mlir/include/mlir/Dialect/SCF/Passes.h
    mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
    mlir/include/mlir/Dialect/SPIRV/Passes.h
    mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/InitAllPasses.h
    mlir/include/mlir/Transforms/CMakeLists.txt
    mlir/include/mlir/Transforms/Passes.h
    mlir/tools/mlir-tblgen/PassGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt b/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
index ab6526ee1833..9acf6f89e12f 100644
--- a/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/CodeGen/CMakeLists.txt
@@ -1,6 +1,6 @@
 
 set(LLVM_TARGET_DEFINITIONS CGPasses.td)
-mlir_tablegen(CGPasses.h.inc -gen-pass-decls)
+mlir_tablegen(CGPasses.h.inc -gen-pass-decls -name OptCodeGen)
 add_public_tablegen_target(FIROptCodeGenPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc OptimizerCodeGenPasses ./)

diff  --git a/flang/include/flang/Optimizer/CodeGen/CodeGen.h b/flang/include/flang/Optimizer/CodeGen/CodeGen.h
index 9b968172f348..a90d0a50dac6 100644
--- a/flang/include/flang/Optimizer/CodeGen/CodeGen.h
+++ b/flang/include/flang/Optimizer/CodeGen/CodeGen.h
@@ -28,12 +28,9 @@ std::unique_ptr<mlir::Pass> createFIRToLLVMPass(NameUniquer &uniquer);
 std::unique_ptr<mlir::Pass>
 createLLVMDialectToLLVMPass(llvm::raw_ostream &output);
 
-inline void registerOptCodeGenPasses() {
-  using mlir::Pass;
 // declarative passes
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
-}
 
 } // namespace fir
 

diff  --git a/flang/include/flang/Optimizer/Transforms/CMakeLists.txt b/flang/include/flang/Optimizer/Transforms/CMakeLists.txt
index fde17eb88622..b928991e0a37 100644
--- a/flang/include/flang/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OptTransform)
 add_public_tablegen_target(FIROptTransformsPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc OptimizerTransformPasses ./)

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 9377c2dc61cc..5e71995736e6 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -46,12 +46,9 @@ std::unique_ptr<mlir::Pass> createMemToRegPass();
 bool canLegallyInline(mlir::Operation *op, mlir::Region *reg,
                       mlir::BlockAndValueMapping &map);
 
-inline void registerOptTransformPasses() {
-using mlir::Pass;
 // declarative passes
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/Transforms/Passes.h.inc"
-}
 
 } // namespace fir
 

diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 04a4ca0a7b3c..92ca92218219 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -622,18 +622,34 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
 }
 ```
 
-We can include the generated registration calls via:
+Using the `gen-pass-decls` generator, we can generate much of the
+boilerplate above automatically. This generator takes a `-name` parameter as
+input, which provides a tag for the group of passes that are being generated.
+This generator produces two chunks of output:
+
+The first is the code for registering the declarative passes with the global
+registry. For each pass, the generator produces a `registerFooPass` where `Foo`
+is the name of the definition specified in tablegen. It also generates a
+`registerGroupPasses`, where `Group` is the tag provided via the `-name` input
+parameter, that registers all of the passes present.
 
 ```c++
-void registerMyPasses() {
-  // The generated registration is not static, so we need to include this in
-  // a location that we can call into.
 #define GEN_PASS_REGISTRATION
 #include "Passes.h.inc"
+
+void registerAllMyPasses() {
+  // Register all of our passes.
+  registerMyPasses();
+
+  // Register `MyPass` specifically.
+  registerMyPassPass();
 }
 ```
 
-We can then update the original C++ pass definition:
+The second is a base class for each of the passes, with each containing most of
+the boilerplate related to pass definition. These classes are named in the form
+of `MyPassBase`, where `MyPass` is the name of the definition in tablegen. We
+can update the original C++ pass definition as follows:
 
 ```c++
 /// Include the generated base pass class definitions.
@@ -651,6 +667,10 @@ std::unique_ptr<Pass> foo::createMyPass() {
 }
 ```
 
+Using the `gen-pass-doc` generator, we can generate markdown documentation for
+each of our passes. See [Passes.md](Passes.md) for example output of real MLIR
+passes.
+
 ### Tablegen Specification
 
 The `Pass` class is used to begin a new pass definition. This class takes as an

diff  --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
index fdd203a6f6ef..aff5c4ca2c70 100644
--- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
+++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
-#define MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+#ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+#define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
 
 #include <memory>
 
@@ -26,4 +26,4 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertAVX512ToLLVMPass();
 
 } // namespace mlir
 
-#endif // MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+#endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_

diff  --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index 4deffafe0ec6..4647cacdd9cd 100644
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -20,6 +20,7 @@ class Location;
 struct LogicalResult;
 class MLIRContext;
 class OpBuilder;
+class Pass;
 class RewritePattern;
 class Value;
 class ValueRange;
@@ -57,6 +58,12 @@ Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder);
 /// Emit code that computes the upper bound of the given affine loop using
 /// standard arithmetic operations.
 Value lowerAffineUpperBound(AffineForOp op, OpBuilder &builder);
+
+/// Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp)
+/// to equivalent lower-level constructs (flow of basic blocks and arithmetic
+/// primitives).
+std::unique_ptr<Pass> createLowerAffinePass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_AFFINETOSTANDARD_AFFINETOSTANDARD_H

diff  --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt
index d4ce2634f450..ae0afc97dc63 100644
--- a/mlir/include/mlir/Conversion/CMakeLists.txt
+++ b/mlir/include/mlir/Conversion/CMakeLists.txt
@@ -1,6 +1,6 @@
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion)
 add_public_tablegen_target(MLIRConversionPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc ConversionPasses ./)

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
new file mode 100644
index 000000000000..87f2c97e766d
--- /dev/null
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Conversion Pass Construction and Registration -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_PASSES_H
+#define MLIR_CONVERSION_PASSES_H
+
+#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
+#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
+#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
+#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
+#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+
+namespace mlir {
+
+/// Generate the code for registering conversion passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_PASSES_H

diff  --git a/mlir/include/mlir/Dialect/Affine/CMakeLists.txt b/mlir/include/mlir/Dialect/Affine/CMakeLists.txt
index 404c926f60ed..96d951dedf4c 100644
--- a/mlir/include/mlir/Dialect/Affine/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Affine/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_subdirectory(IR)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Affine)
 add_public_tablegen_target(MLIRAffinePassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc AffinePasses ./)

diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 18b3b790338d..f2cef42a4356 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -14,17 +14,12 @@
 #ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_AFFINE_TRANSFORMS_PASSES_H
 
-#include "mlir/Support/LLVM.h"
-#include <functional>
+#include "mlir/Pass/Pass.h"
 #include <limits>
 
 namespace mlir {
 
 class AffineForOp;
-class FuncOp;
-class ModuleOp;
-class Pass;
-template <typename T> class OperationPass;
 
 /// Creates a simplification pass for affine structures (maps and sets). In
 /// addition, this pass also normalizes memrefs to have the trivial (identity)
@@ -79,6 +74,14 @@ createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize);
 /// Overload relying on pass options for initialization.
 std::unique_ptr<OperationPass<FuncOp>> createSuperVectorizePass();
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Affine/Passes.h.inc"
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_RANSFORMS_PASSES_H

diff  --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
index 6c80b4c8e3b9..68313c978842 100644
--- a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -12,7 +12,7 @@ mlir_tablegen(ParallelLoopMapperEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRParallelLoopMapperEnumsGen)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name GPU)
 add_public_tablegen_target(MLIRGPUPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc GPUPasses ./)

diff  --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
index bc349061f39f..64b744b6b172 100644
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -13,21 +13,23 @@
 #ifndef MLIR_DIALECT_GPU_PASSES_H_
 #define MLIR_DIALECT_GPU_PASSES_H_
 
-#include <memory>
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-
-class MLIRContext;
-class ModuleOp;
-template <typename T> class OperationPass;
-class OwningRewritePatternList;
-
 std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
 
 /// Collect a set of patterns to rewrite ops within the GPU dialect.
 void populateGpuRewritePatterns(MLIRContext *context,
                                 OwningRewritePatternList &patterns);
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/GPU/Passes.h.inc"
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_GPU_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
index a2fd81c23e11..a744b0706ffd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name LLVM)
 add_public_tablegen_target(MLIRLLVMPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc LLVMPasses ./)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
new file mode 100644
index 000000000000..868a0e563510
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
@@ -0,0 +1,26 @@
+//===- Passes.h - LLVM Pass Construction and Registration -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace LLVM {
+
+/// Generate the code for registering LLVM dialect passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H

diff  --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
index 66ac74515ddd..d0edae3979e0 100644
--- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_subdirectory(IR)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)
 add_public_tablegen_target(MLIRLinalgPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc LinalgPasses ./)

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index a5c09b3f75b7..d74714cdaa56 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -13,17 +13,9 @@
 #ifndef MLIR_DIALECT_LINALG_PASSES_H_
 #define MLIR_DIALECT_LINALG_PASSES_H_
 
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/ArrayRef.h"
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-class FuncOp;
-class MLIRContext;
-class ModuleOp;
-template <typename T> class OperationPass;
-class OwningRewritePatternList;
-class Pass;
-
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFoldUnitExtentDimsPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createLinalgFusionPass();
@@ -66,6 +58,14 @@ void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
 void populateLinalgFoldUnitExtentDimsPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
index 1a48e4928b33..177d129a805a 100644
--- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect(QuantOps quant)
 add_mlir_doc(QuantOps -gen-dialect-doc QuantDialect Dialects/)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant)
 add_public_tablegen_target(MLIRQuantPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc QuantPasses ./)

diff  --git a/mlir/include/mlir/Dialect/Quant/Passes.h b/mlir/include/mlir/Dialect/Quant/Passes.h
index b938c9a86b72..090653eabe3f 100644
--- a/mlir/include/mlir/Dialect/Quant/Passes.h
+++ b/mlir/include/mlir/Dialect/Quant/Passes.h
@@ -16,12 +16,9 @@
 #ifndef MLIR_DIALECT_QUANT_PASSES_H
 #define MLIR_DIALECT_QUANT_PASSES_H
 
-#include <memory>
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-class FuncOp;
-template <typename T> class OperationPass;
-
 namespace quant {
 
 /// Creates a pass that converts quantization simulation operations (i.e.
@@ -35,6 +32,14 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertSimulatedQuantPass();
 /// destructive and cannot be undone.
 std::unique_ptr<OperationPass<FuncOp>> createConvertConstPass();
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Quant/Passes.h.inc"
+
 } // namespace quant
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt
index 9467b97b384b..546ada0224cf 100644
--- a/mlir/include/mlir/Dialect/SCF/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SCF/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect(SCFOps scf Ops)
 add_mlir_doc(SCFOps -gen-dialect-doc SCFDialect Dialects/)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name SCF)
 add_public_tablegen_target(MLIRSCFPassIncGen)
 add_dependencies(mlir-headers MLIRSCFPassIncGen)
 

diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index df6037874f2b..7edb2444e87c 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -13,13 +13,10 @@
 #ifndef MLIR_DIALECT_SCF_PASSES_H_
 #define MLIR_DIALECT_SCF_PASSES_H_
 
-#include "llvm/ADT/ArrayRef.h"
-#include <memory>
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
 
-class Pass;
-
 /// Creates a pass that specializes for loop for unrolling and
 /// vectorization.
 std::unique_ptr<Pass> createForLoopSpecializationPass();
@@ -35,6 +32,14 @@ std::unique_ptr<Pass> createParallelLoopSpecializationPass();
 std::unique_ptr<Pass>
 createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/SCF/Passes.h.inc"
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
index 1e0901f07e91..ff078ef9d946 100644
--- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
@@ -38,7 +38,7 @@ add_public_tablegen_target(MLIRSPIRVTargetAndABIIncGen)
 add_dependencies(mlir-headers MLIRSPIRVTargetAndABIIncGen)
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name SPIRV)
 add_public_tablegen_target(MLIRSPIRVPassIncGen)
 add_dependencies(mlir-headers MLIRSPIRVPassIncGen)
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
index df516430be52..dbd2c93a53a5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -50,6 +50,14 @@ std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
 /// spv.CompositeInsert into spv.CompositeConstruct.
 std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/SPIRV/Passes.h.inc"
+
 } // namespace spirv
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
index 629b8c0db294..8bbe1cb3fbc6 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Shape)
 add_public_tablegen_target(MLIRShapeTransformsIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc ShapePasses ./)

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index e8d2167916d0..543ffc617a5c 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -14,15 +14,9 @@
 #ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
 
-#include <memory>
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
-
-class FunctionPass;
-class MLIRContext;
-class OwningRewritePatternList;
-class Pass;
-
 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
 /// dialect to be convertible to Standard. For example, `shape.num_elements` get
 /// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
@@ -42,6 +36,14 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
                                             MLIRContext *ctx);
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
index 413c6523a756..f1cc5d81e0fe 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Standard)
 add_public_tablegen_target(MLIRStandardTransformsIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc StandardPasses ./)

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index aadc41d2790d..fba5f4b32043 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -15,12 +15,10 @@
 #ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
 
-#include <memory>
+#include "mlir/Pass/Pass.h"
 
 namespace mlir {
 
-class Pass;
-class MLIRContext;
 class OwningRewritePatternList;
 
 /// Creates an instance of the ExpandAtomic pass.
@@ -29,6 +27,14 @@ std::unique_ptr<Pass> createExpandAtomicPass();
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index a2810f3b270b..7d0a7726ea6c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -14,38 +14,17 @@
 #ifndef MLIR_INITALLPASSES_H_
 #define MLIR_INITALLPASSES_H_
 
-#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
-#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
-#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
-#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
-#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
-#include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
-#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
-#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
-#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
-#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
-#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
-#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Conversion/Passes.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/GPU/Passes.h"
-#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Quant/Passes.h"
 #include "mlir/Dialect/SCF/Passes.h"
 #include "mlir/Dialect/SPIRV/Passes.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
-#include "mlir/Transforms/LocationSnapshot.h"
 #include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/ViewOpGraph.h"
-#include "mlir/Transforms/ViewRegionGraph.h"
 
 #include <cstdlib>
 
@@ -59,49 +38,22 @@ namespace mlir {
 // individual passes.
 // The global registry is interesting to interact with the command-line tools.
 inline void registerAllPasses() {
-  // Init general passes
-#define GEN_PASS_REGISTRATION
-#include "mlir/Transforms/Passes.h.inc"
+  // General passes
+  registerTransformsPasses();
 
   // Conversion passes
-#define GEN_PASS_REGISTRATION
-#include "mlir/Conversion/Passes.h.inc"
-
-  // Affine
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Affine/Passes.h.inc"
-
-  // GPU
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/GPU/Passes.h.inc"
-
-  // Linalg
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Linalg/Passes.h.inc"
-
-  // LLVM
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
-
-  // Loop
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/SCF/Passes.h.inc"
-
-  // Quant
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Quant/Passes.h.inc"
-
-  // SPIR-V
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/SPIRV/Passes.h.inc"
-
-  // Standard
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
-
-  // Shape
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
+  registerConversionPasses();
+
+  // Dialect passes
+  registerAffinePasses();
+  registerGPUPasses();
+  registerLinalgPasses();
+  LLVM::registerLLVMPasses();
+  quant::registerQuantPasses();
+  registerSCFPasses();
+  registerShapePasses();
+  spirv::registerSPIRVPasses();
+  registerStandardPasses();
 }
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt
index 706193188edd..f1006e06757b 100644
--- a/mlir/include/mlir/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 
 set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transforms)
 add_public_tablegen_target(MLIRTransformsPassIncGen)
 
 add_mlir_doc(Passes -gen-pass-doc GeneralPasses ./)

diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 1ffff1a25a6d..ef8524ebd28a 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -14,19 +14,19 @@
 #ifndef MLIR_TRANSFORMS_PASSES_H
 #define MLIR_TRANSFORMS_PASSES_H
 
-#include "mlir/Support/LLVM.h"
-#include <functional>
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LocationSnapshot.h"
+#include "mlir/Transforms/ViewOpGraph.h"
+#include "mlir/Transforms/ViewRegionGraph.h"
 #include <limits>
 
 namespace mlir {
 
 class AffineForOp;
-class FuncOp;
-class ModuleOp;
-class Pass;
 
-template <typename T>
-class OperationPass;
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
 
 /// Creates an instance of the BufferPlacement pass.
 std::unique_ptr<Pass> createBufferPlacementPass();
@@ -95,6 +95,15 @@ std::unique_ptr<Pass> createSymbolDCEPass();
 /// Creates an interprocedural pass to normalize memrefs to have a trivial
 /// (identity) layout map.
 std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Transforms/Passes.h.inc"
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_PASSES_H

diff  --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index f8998f09a436..c2dcdb8e4ac9 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -14,6 +14,7 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Pass.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -21,6 +22,11 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
+static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
+static llvm::cl::opt<std::string>
+    groupName("name", llvm::cl::desc("The name of this group of passes"),
+              llvm::cl::cat(passGenCat));
+
 //===----------------------------------------------------------------------===//
 // GEN: Pass base class generation
 //===----------------------------------------------------------------------===//
@@ -109,36 +115,49 @@ static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
 // GEN: Pass registration generation
 //===----------------------------------------------------------------------===//
 
+/// The code snippet used to generate the registration function for a pass.
+/// It is a formatv pattern, so each literal '{' is written as '{{'.
+/// {0}: The def name of the pass record.
+/// {1}: The argument of the pass.
+/// {2}: The summary of the pass (inserted into a string literal as-is).
+/// {3}: The code for constructing the pass.
+const char *const passRegistrationCode = R"(
+//===----------------------------------------------------------------------===//
+// {0} Registration
+//===----------------------------------------------------------------------===//
+
+inline void register{0}Pass() {{
+  ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
+    return {3};
+  });
+}
+)";
+
+/// Opening of register{0}Passes(). {0}: The name of the pass group.
+const char *const passGroupRegistrationCode = R"(
+//===----------------------------------------------------------------------===//
+// {0} Registration
+//===----------------------------------------------------------------------===//
+
+inline void register{0}Passes() {{
+)";
+
 /// Emit the code for registering each of the given passes with the global
 /// PassRegistry.
 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
   os << "#ifdef GEN_PASS_REGISTRATION\n";
   for (const Pass &pass : passes) {
-    os << llvm::formatv("#define GEN_PASS_REGISTRATION_{0}\n",
-                        pass.getDef()->getName());
-  }
-  os << "#endif // GEN_PASS_REGISTRATION\n";
-
-  for (const Pass &pass : passes) {
-    os << llvm::formatv("#ifdef GEN_PASS_REGISTRATION_{0}\n",
-                        pass.getDef()->getName());
-    os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> "
-                        "std::unique_ptr<::mlir::Pass> {{ return {2}; });\n",
+    os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
                         pass.getArgument(), pass.getSummary(),
                         pass.getConstructor());
-    os << llvm::formatv("#endif // GEN_PASS_REGISTRATION_{0}\n",
-                        pass.getDef()->getName());
-    os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n",
-                        pass.getDef()->getName());
   }
 
-  os << "#ifdef GEN_PASS_REGISTRATION\n";
-  for (const Pass &pass : passes) {
-    os << llvm::formatv("#undef GEN_PASS_REGISTRATION_{0}\n",
-                        pass.getDef()->getName());
-  }
-  os << "#endif // GEN_PASS_REGISTRATION\n";
+  os << llvm::formatv(passGroupRegistrationCode, groupName);
+  for (const Pass &pass : passes)
+    os << "  register" << pass.getDef()->getName() << "Pass();\n";
+  os << "}\n";
   os << "#undef GEN_PASS_REGISTRATION\n";
+  os << "#endif // GEN_PASS_REGISTRATION\n";
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list