[Mlir-commits] [mlir] 722f909 - [mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>

River Riddle llvmlistbot at llvm.org
Tue Apr 7 14:10:10 PDT 2020


Author: River Riddle
Date: 2020-04-07T14:08:52-07:00
New Revision: 722f909f7aa1d5ab21f68eb8ce1baf109cc5bb13

URL: https://github.com/llvm/llvm-project/commit/722f909f7aa1d5ab21f68eb8ce1baf109cc5bb13
DIFF: https://github.com/llvm/llvm-project/commit/722f909f7aa1d5ab21f68eb8ce1baf109cc5bb13.diff

LOG: [mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>

ModulePass doesn't provide any special utilities and thus doesn't give enough benefit to warrant a special pass class. This revision replaces all usages with the more general OperationPass.

Differential Revision: https://reviews.llvm.org/D77339

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-6.md
    mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
    mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
    mlir/include/mlir/Pass/Pass.h
    mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
    mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
    mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
    mlir/lib/Transforms/OpStats.cpp
    mlir/lib/Transforms/ViewOpGraph.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/IR/TestFunc.cpp
    mlir/test/lib/IR/TestSideEffects.cpp
    mlir/test/lib/IR/TestSymbolUses.cpp
    mlir/test/lib/Pass/TestPassManager.cpp
    mlir/test/lib/Transforms/TestAllReduceLowering.cpp
    mlir/test/lib/Transforms/TestCallGraph.cpp
    mlir/test/lib/Transforms/TestOpaqueLoc.cpp

Removed: 
    


################################################################################
diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 0444d2a7690a..e1dfc0039f8e 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -105,7 +105,7 @@ We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
 that only legal operations will remain after the conversion.
 
 ```c++
-  mlir::ModuleOp module = getModule();
+  mlir::ModuleOp module = getOperation();
   if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
                                              &typeConverter)))
     signalPassFailure();

diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index f6dcba229276..99465d3201e5 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -153,12 +153,13 @@ class PrintOpLowering : public ConversionPattern {
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
-  void runOnModule() final;
+struct ToyToLLVMLoweringPass
+    : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
+  void runOnOperation() final;
 };
 } // end anonymous namespace
 
-void ToyToLLVMLoweringPass::runOnModule() {
+void ToyToLLVMLoweringPass::runOnOperation() {
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering. For this lowering, we are only targeting
   // the LLVM dialect.
@@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
-  auto module = getModule();
+  auto module = getOperation();
   if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
     signalPassFailure();
 }

diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index f6dcba229276..99465d3201e5 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -153,12 +153,13 @@ class PrintOpLowering : public ConversionPattern {
 //===----------------------------------------------------------------------===//
 
 namespace {
-struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
-  void runOnModule() final;
+struct ToyToLLVMLoweringPass
+    : public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
+  void runOnOperation() final;
 };
 } // end anonymous namespace
 
-void ToyToLLVMLoweringPass::runOnModule() {
+void ToyToLLVMLoweringPass::runOnOperation() {
   // The first thing to define is the conversion target. This will define the
   // final target for this lowering. For this lowering, we are only targeting
   // the LLVM dialect.
@@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
 
   // We want to completely lower to LLVM, so we use a `FullConversion`. This
   // ensures that only legal operations will remain after the conversion.
-  auto module = getModule();
+  auto module = getOperation();
   if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
     signalPassFailure();
 }

diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 80c4ddfeae62..c1eec4f4706a 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -341,24 +341,9 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> {
       runOnFunction();
   }
 
-  /// Return the current module being transformed.
+  /// Return the current function being transformed.
   FuncOp getFunction() { return this->getOperation(); }
 };
-
-/// A model for providing module pass specific utilities.
-///
-/// Derived module passes are expected to provide the following:
-///   - A 'void runOnModule()' method.
-template <typename T> struct ModulePass : public OperationPass<T, ModuleOp> {
-  /// The polymorphic API that runs the pass over the currently held module.
-  virtual void runOnModule() = 0;
-
-  /// The polymorphic API that runs the pass over the currently held operation.
-  void runOnOperation() final { runOnModule(); }
-
-  /// Return the current module being transformed.
-  ModuleOp getModule() { return this->getOperation(); }
-};
 } // end namespace mlir
 
 #endif // MLIR_PASS_PASS_H

diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
index 91f3cc933a02..08b187fc835e 100644
--- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -163,16 +163,17 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
 }
 
 namespace {
-struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
+struct ConvertAVX512ToLLVMPass
+    : public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertAVX512ToLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void ConvertAVX512ToLLVMPass::runOnModule() {
+void ConvertAVX512ToLLVMPass::runOnOperation() {
   // Convert to the LLVM IR dialect.
   OwningRewritePatternList patterns;
   LLVMTypeConverter converter(&getContext());
@@ -186,8 +187,8 @@ void ConvertAVX512ToLLVMPass::runOnModule() {
   target.addIllegalDialect<avx512::AVX512Dialect>();
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(
-          applyPartialConversion(getModule(), target, patterns, &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns,
+                                    &converter))) {
     signalPassFailure();
   }
 }

diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index 38c092a2eaf0..71fe129d3875 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -61,7 +61,7 @@ namespace {
 ///
 /// Intermediate data structures are allocated on the stack.
 class GpuLaunchFuncToCudaCallsPass
-    : public ModulePass<GpuLaunchFuncToCudaCallsPass> {
+    : public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> {
 private:
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls
@@ -126,20 +126,19 @@ class GpuLaunchFuncToCudaCallsPass
 
 public:
   // Run the dialect converter on the module.
-  void runOnModule() override {
+  void runOnOperation() override {
     // Cache the LLVMDialect for the current module.
     llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
     // Cache the used LLVM types.
     initializeCachedTypes();
 
-    getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
-      translateGpuLaunchCalls(op);
-    });
+    getOperation().walk(
+        [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
 
     // GPU kernel modules are no longer necessary since we have a global
     // constant with the CUBIN data.
     for (auto m :
-         llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+         llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
       m.erase();
   }
 
@@ -160,7 +159,7 @@ class GpuLaunchFuncToCudaCallsPass
 // The types in comments give the actual types expected/returned but the API
 // uses void pointers. This is fine as they have the same linkage in C.
 void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
   OpBuilder builder(module.getBody()->getTerminator());
   if (!module.lookupSymbol(cuModuleLoadName)) {
     builder.create<LLVM::LLVMFuncOp>(
@@ -391,7 +390,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
                                                builder.getI32IntegerAttr(0));
   // Create an LLVM global with CUBIN extracted from the kernel annotation and
   // obtain a pointer to the first byte in it.
-  auto kernelModule = getModule().lookupSymbol<gpu::GPUModuleOp>(
+  auto kernelModule = getOperation().lookupSymbol<gpu::GPUModuleOp>(
       launchOp.getKernelModuleName());
   assert(kernelModule && "expected a kernel module");
 
@@ -412,7 +411,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   // in the called helper function.
   auto cuModule = allocatePointer(builder, loc);
   auto cuModuleLoad =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
                                builder.getSymbolRefAttr(cuModuleLoad),
                                ArrayRef<Value>{cuModule, data});
@@ -423,20 +422,20 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
   auto cuFunction = allocatePointer(builder, loc);
   auto cuModuleGetFunction =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
   builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getCUResultType()},
       builder.getSymbolRefAttr(cuModuleGetFunction),
       ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
   // Grab the global stream needed for execution.
   auto cuGetStreamHelper =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
   auto cuStream = builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getPointerType()},
       builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
   // Invoke the function with required arguments.
   auto cuLaunchKernel =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
   auto cuFunctionRef =
       builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
   auto paramsArray = setupParamsArray(launchOp, builder);
@@ -458,7 +457,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
                       nullpointer /* extra */});
   // Sync on the stream to make it synchronous.
   auto cuStreamSync =
-      getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
+      getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
   builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
                                builder.getSymbolRefAttr(cuStreamSync),
                                ArrayRef<Value>(cuStream.getResult(0)));

diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 1102ef182c5f..edee5025ded9 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -33,18 +33,18 @@ namespace {
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-struct GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
+struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertGpuToSPIRV
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void GPUToSPIRVPass::runOnModule() {
+void GPUToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
 
   SmallVector<Operation *, 1> kernelModules;
   OpBuilder builder(context);

diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 823860ba2589..cbcfd741d9f8 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -38,13 +38,13 @@ namespace {
 /// function and attaching binary data and entry point name as an attributes to
 /// created vulkan launch call op.
 class ConvertGpuLaunchFuncToVulkanLaunchFunc
-    : public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
+    : public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> {
 public:
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 
 private:
   /// Creates a SPIR-V binary shader from the given `module` using
@@ -68,14 +68,13 @@ class ConvertGpuLaunchFuncToVulkanLaunchFunc
   /// operand is unsupported by Vulkan runtime.
   LogicalResult declareVulkanLaunchFunc(Location loc,
                                         gpu::LaunchFuncOp launchOp);
-
 };
 
 } // anonymous namespace
 
-void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
+void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
   bool done = false;
-  getModule().walk([this, &done](gpu::LaunchFuncOp op) {
+  getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
     if (done) {
       op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
       return signalPassFailure();
@@ -86,17 +85,17 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
 
   // Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
   for (auto gpuModule :
-       llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
+       llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
     gpuModule.erase();
 
   for (auto spirvModule :
-       llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
+       llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
     spirvModule.erase();
 }
 
 LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
     Location loc, gpu::LaunchFuncOp launchOp) {
-  OpBuilder builder(getModule().getBody()->getTerminator());
+  OpBuilder builder(getOperation().getBody()->getTerminator());
   // TODO: Workgroup size is written into the kernel. So to properly modelling
   // vulkan launch, we cannot have the local workgroup size configuration here.
   SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
@@ -138,7 +137,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
 
 void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
     gpu::LaunchFuncOp launchOp) {
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
   OpBuilder builder(launchOp);
   Location loc = launchOp.getLoc();
 

diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index ebc8ded483ff..2daa13085bcb 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -58,7 +58,7 @@ namespace {
 /// * deinitVulkan         -- deinitializes vulkan runtime
 ///
 class VulkanLaunchFuncToVulkanCallsPass
-    : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
+    : public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> {
 private:
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
@@ -150,7 +150,7 @@ class VulkanLaunchFuncToVulkanCallsPass
   LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
 
 public:
-  void runOnModule() override;
+  void runOnOperation() override;
 
 private:
   LLVM::LLVMDialect *llvmDialect;
@@ -169,18 +169,18 @@ class VulkanLaunchFuncToVulkanCallsPass
 
 } // anonymous namespace
 
-void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
+void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
   initializeCachedTypes();
 
   // Collect SPIR-V attributes such as `spirv_blob` and
   // `spirv_entry_point_name`.
-  getModule().walk([this](LLVM::CallOp op) {
+  getOperation().walk([this](LLVM::CallOp op) {
     if (isVulkanLaunchCallOp(op))
       collectSPIRVAttributes(op);
   });
 
   // Convert vulkan launch call op into a sequence of Vulkan runtime calls.
-  getModule().walk([this](LLVM::CallOp op) {
+  getOperation().walk([this](LLVM::CallOp op) {
     if (isCInterfaceVulkanLaunchCallOp(op))
       translateVulkanLaunchCall(op);
   });
@@ -278,7 +278,7 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
 }
 
 void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
   OpBuilder builder(module.getBody()->getTerminator());
 
   if (!module.lookupSymbol(kSetEntryPoint)) {

diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 07c8111941e4..99f106e29de7 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -561,17 +561,18 @@ void mlir::populateLinalgToLLVMConversionPatterns(
 }
 
 namespace {
-struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
+struct ConvertLinalgToLLVMPass
+    : public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertLinalgToLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void ConvertLinalgToLLVMPass::runOnModule() {
-  auto module = getModule();
+void ConvertLinalgToLLVMPass::runOnOperation() {
+  auto module = getOperation();
 
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;

diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index 0962746c486a..4b66063b88eb 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -16,18 +16,18 @@ using namespace mlir;
 
 namespace {
 /// A pass converting MLIR Linalg ops into SPIR-V ops.
-class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
+class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertLinalgToSPIRV
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void LinalgToSPIRVPass::runOnModule() {
+void LinalgToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 1e127a0a884e..ef5dabf2ff88 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2847,7 +2847,7 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
 
 namespace {
 /// A pass converting MLIR operations into the LLVM IR dialect.
-struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
+struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertStandardToLLVM
 #include "mlir/Conversion/Passes.h.inc"
@@ -2863,16 +2863,16 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   LLVMLoweringPass(const LLVMLoweringPass &pass) {}
 
   /// Run the dialect converter on the module.
-  void runOnModule() override {
+  void runOnOperation() override {
     if (useBarePtrCallConv && emitCWrappers) {
-      getModule().emitError()
+      getOperation().emitError()
           << "incompatible conversion options: bare-pointer calling convention "
              "and C wrapper emission";
       signalPassFailure();
       return;
     }
 
-    ModuleOp m = getModule();
+    ModuleOp m = getOperation();
 
     LLVMTypeConverterCustomization customs;
     customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter

diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
index ab7dd8546995..86c8cd17433c 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -22,18 +22,18 @@ using namespace mlir;
 namespace {
 /// A pass converting MLIR Standard operations into the SPIR-V dialect.
 class ConvertStandardToSPIRVPass
-    : public ModulePass<ConvertStandardToSPIRVPass> {
+    : public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertStandardToSPIRV
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void ConvertStandardToSPIRVPass::runOnModule() {
+void ConvertStandardToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getModule();
+  ModuleOp module = getOperation();
 
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d5a4f86d2ca9..b2a1c443f518 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1118,23 +1118,24 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
 }
 
 namespace {
-struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
+struct LowerVectorToLLVMPass
+    : public OperationPass<LowerVectorToLLVMPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_ConvertVectorToLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void LowerVectorToLLVMPass::runOnModule() {
+void LowerVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and
   // all contraction operations. Also applies folding and DCE.
   {
     OwningRewritePatternList patterns;
     populateVectorSlicesLoweringPatterns(patterns, &getContext());
     populateVectorContractLoweringPatterns(patterns, &getContext());
-    applyPatternsGreedily(getModule(), patterns);
+    applyPatternsGreedily(getOperation(), patterns);
   }
 
   // Convert to the LLVM IR dialect.
@@ -1148,8 +1149,8 @@ void LowerVectorToLLVMPass::runOnModule() {
   LLVMConversionTarget target(getContext());
   target.addDynamicallyLegalOp<FuncOp>(
       [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-  if (failed(
-          applyPartialConversion(getModule(), target, patterns, &converter))) {
+  if (failed(applyPartialConversion(getOperation(), target, patterns,
+                                    &converter))) {
     signalPassFailure();
   }
 }

diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 2eadf87f038a..daf9169d242c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -214,16 +214,17 @@ namespace {
 /// The gpu.modules are intended to be compiled to a cubin blob independently in
 /// a separate pass. The external functions can then be annotated with the
 /// symbol of the cubin accessor function.
-class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
+class GpuKernelOutliningPass
+    : public OperationPass<GpuKernelOutliningPass, ModuleOp> {
 public:
 /// Include the generated pass utilities.
 #define GEN_PASS_GpuKernelOutlining
 #include "mlir/Dialect/GPU/Passes.h.inc"
 
-  void runOnModule() override {
-    SymbolTable symbolTable(getModule());
+  void runOnOperation() override {
+    SymbolTable symbolTable(getOperation());
     bool modified = false;
-    for (auto func : getModule().getOps<FuncOp>()) {
+    for (auto func : getOperation().getOps<FuncOp>()) {
       // Insert just after the function.
       Block::iterator insertPt(func.getOperation()->getNextNode());
       auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
@@ -255,8 +256,8 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
     // If any new module was inserted in this module, annotate this module as
     // a container module.
     if (modified)
-      getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
-                          UnitAttr::get(&getContext()));
+      getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
+                             UnitAttr::get(&getContext()));
   }
 
 private:
@@ -267,7 +268,7 @@ class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
     // a SymbolTable by the caller. SymbolTable needs to be refactored to
     // prevent manual building of Ops with symbols in code using SymbolTables
     // and then this needs to use the OpBuilder.
-    auto context = getModule().getContext();
+    auto context = getOperation().getContext();
     Builder builder(context);
     OperationState state(kernelFunc.getLoc(),
                          gpu::GPUModuleOp::getOperationName());

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
index 79ed81956f08..e4622741536e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp
@@ -80,14 +80,14 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
 
 namespace {
 class DecorateSPIRVCompositeTypeLayoutPass
-    : public ModulePass<DecorateSPIRVCompositeTypeLayoutPass> {
+    : public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> {
 private:
-  void runOnModule() override;
+  void runOnOperation() override;
 };
 } // namespace
 
-void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
-  auto module = getModule();
+void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
+  auto module = getOperation();
   OwningRewritePatternList patterns;
   populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
   ConversionTarget target(*(module.getContext()));

diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index b7832f580dd4..2b519d697020 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -18,7 +18,7 @@
 using namespace mlir;
 
 namespace {
-struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
+struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_PrintOpStats
 #include "mlir/Transforms/Passes.h.inc"
@@ -26,7 +26,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
   explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
 
   // Prints the resultant operation statistics post iterating over the module.
-  void runOnModule() override;
+  void runOnOperation() override;
 
   // Print summary of op stats.
   void printSummary();
@@ -37,11 +37,11 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
 };
 } // namespace
 
-void PrintOpStatsPass::runOnModule() {
+void PrintOpStatsPass::runOnOperation() {
   opCount.clear();
 
   // Compute the operation statistics for each function in the module.
-  for (auto &op : getModule())
+  for (auto &op : getOperation())
     op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
   printSummary();
 }

diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index fcaff9a0b069..c5d921db059e 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -100,7 +100,7 @@ namespace {
 // PrintOpPass is simple pass to write graph per function.
 // Note: this is a module pass only to avoid interleaving on the same ostream
 // due to multi-threading over functions.
-struct PrintOpPass : public ModulePass<PrintOpPass> {
+struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> {
 /// Include the generated pass utilities.
 #define GEN_PASS_PrintOpGraph
 #include "mlir/Transforms/Passes.h.inc"
@@ -140,7 +140,7 @@ struct PrintOpPass : public ModulePass<PrintOpPass> {
     }
   }
 
-  void runOnModule() override { processModule(getModule()); }
+  void runOnOperation() override { processModule(getOperation()); }
 
 private:
   raw_ostream &os;

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e6cc52d29722..6ccfa04a8194 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -398,13 +398,13 @@ struct TestTypeConverter : public TypeConverter {
 };
 
 struct TestLegalizePatternDriver
-    : public ModulePass<TestLegalizePatternDriver> {
+    : public OperationPass<TestLegalizePatternDriver, ModuleOp> {
   /// The mode of conversion to use with the driver.
   enum class ConversionMode { Analysis, Full, Partial };
 
   TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
 
-  void runOnModule() override {
+  void runOnOperation() override {
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
@@ -450,7 +450,8 @@ struct TestLegalizePatternDriver
 
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
-      (void)applyPartialConversion(getModule(), target, patterns, &converter);
+      (void)applyPartialConversion(getOperation(), target, patterns,
+                                   &converter);
       return;
     }
 
@@ -461,7 +462,7 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
-      (void)applyFullConversion(getModule(), target, patterns, &converter);
+      (void)applyFullConversion(getOperation(), target, patterns, &converter);
       return;
     }
 
@@ -470,7 +471,7 @@ struct TestLegalizePatternDriver
 
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
-    if (failed(applyAnalysisConversion(getModule(), target, patterns,
+    if (failed(applyAnalysisConversion(getOperation(), target, patterns,
                                        legalizedOps, &converter)))
       return signalPassFailure();
 

diff  --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
index 0e885c555e38..c1b90397ec44 100644
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -13,9 +13,9 @@ using namespace mlir;
 
 namespace {
 /// This is a test pass for verifying FuncOp's eraseArgument method.
-struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
-  void runOnModule() override {
-    auto module = getModule();
+struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
 
     for (FuncOp func : module.getOps<FuncOp>()) {
       SmallVector<unsigned, 4> indicesToErase;
@@ -36,9 +36,9 @@ struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
 };
 
 /// This is a test pass for verifying FuncOp's setType method.
-struct TestFuncSetType : public ModulePass<TestFuncSetType> {
-  void runOnModule() override {
-    auto module = getModule();
+struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
     SymbolTable symbolTable(module);
 
     for (FuncOp func : module.getOps<FuncOp>()) {

diff  --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp
index 9f52c42e4953..a99348537e25 100644
--- a/mlir/test/lib/IR/TestSideEffects.cpp
+++ b/mlir/test/lib/IR/TestSideEffects.cpp
@@ -12,9 +12,9 @@
 using namespace mlir;
 
 namespace {
-struct SideEffectsPass : public ModulePass<SideEffectsPass> {
-  void runOnModule() override {
-    auto module = getModule();
+struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
 
     // Walk operations detecting side effects.
     SmallVector<MemoryEffects::EffectInstance, 8> effects;

diff  --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 6082cdcbe72b..c39615ef1352 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -15,7 +15,7 @@ using namespace mlir;
 namespace {
 /// This is a symbol test pass that tests the symbol uselist functionality
 /// provided by the symbol table along with erasing from the symbol table.
-struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
+struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> {
   WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
                              SmallVectorImpl<FuncOp> &deadFunctions) {
     // Test computing uses on a non symboltable op.
@@ -59,8 +59,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
     return WalkResult::advance();
   }
 
-  void runOnModule() override {
-    auto module = getModule();
+  void runOnOperation() override {
+    auto module = getOperation();
 
     // Walk nested symbols.
     SmallVector<FuncOp, 4> deadFunctions;
@@ -86,9 +86,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
 
 /// This is a symbol test pass that tests the symbol use replacement
 /// functionality provided by the symbol table.
-struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
-  void runOnModule() override {
-    auto module = getModule();
+struct SymbolReplacementPass
+    : public OperationPass<SymbolReplacementPass, ModuleOp> {
+  void runOnOperation() override {
+    auto module = getOperation();
 
     // Walk nested functions and modules.
     module.getBodyRegion().walk([&](Operation *nestedOp) {

diff  --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 95bef9b878e2..be8a7479200c 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -13,8 +13,8 @@
 using namespace mlir;
 
 namespace {
-struct TestModulePass : public ModulePass<TestModulePass> {
-  void runOnModule() final {}
+struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> {
+  void runOnOperation() final {}
 };
 struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
   void runOnFunction() final {}

diff  --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
index 508f70887350..6455dab70f45 100644
--- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
+++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
@@ -18,11 +18,11 @@ using namespace mlir;
 
 namespace {
 struct TestAllReduceLoweringPass
-    : public ModulePass<TestAllReduceLoweringPass> {
-  void runOnModule() override {
+    : public OperationPass<TestAllReduceLoweringPass, ModuleOp> {
+  void runOnOperation() override {
     OwningRewritePatternList patterns;
     populateGpuRewritePatterns(&getContext(), patterns);
-    applyPatternsGreedily(getModule(), patterns);
+    applyPatternsGreedily(getOperation(), patterns);
   }
 };
 } // namespace

diff  --git a/mlir/test/lib/Transforms/TestCallGraph.cpp b/mlir/test/lib/Transforms/TestCallGraph.cpp
index 89c25da9e8ed..a181d645f2af 100644
--- a/mlir/test/lib/Transforms/TestCallGraph.cpp
+++ b/mlir/test/lib/Transforms/TestCallGraph.cpp
@@ -17,9 +17,9 @@
 using namespace mlir;
 
 namespace {
-struct TestCallGraphPass : public ModulePass<TestCallGraphPass> {
-  void runOnModule() {
-    llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n";
+struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> {
+  void runOnOperation() override {
+    llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n";
     getAnalysis<CallGraph>().print(llvm::errs());
   }
 };

diff  --git a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
index baae5297306d..47152c459805 100644
--- a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
+++ b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp
@@ -17,7 +17,7 @@ namespace {
 /// It also takes all operations that are not function operations or
 /// terminators and clones them with opaque locations which store the initial
 /// locations.
-struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
+struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> {
 
   /// A simple structure which is used for testing as an underlying location in
   /// OpaqueLoc.
@@ -29,11 +29,11 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
     int id;
   };
 
-  void runOnModule() override {
+  void runOnOperation() override {
     std::vector<std::unique_ptr<MyLocation>> myLocs;
     int last_it = 0;
 
-    getModule().walk([&](Operation *op) {
+    getOperation().walk([&](Operation *op) {
       myLocs.push_back(std::make_unique<MyLocation>(last_it++));
 
       Location loc = op->getLoc();
@@ -74,7 +74,7 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
       os.flush();
     });
 
-    getModule().walk([&](Operation *op) { op->emitOpError(); });
+    getOperation().walk([&](Operation *op) { op->emitOpError(); });
   }
 };
 


        


More information about the Mlir-commits mailing list