[Mlir-commits] [mlir] f230d91 - [mlir][spirv] Turn various passes to plain OperationPass

Lei Zhang llvmlistbot at llvm.org
Wed Aug 10 10:50:15 PDT 2022


Author: jackalcooper
Date: 2022-08-10T13:50:07-04:00
New Revision: f230d91592a17176c626b682880e6d6f49862475

URL: https://github.com/llvm/llvm-project/commit/f230d91592a17176c626b682880e6d6f49862475
DIFF: https://github.com/llvm/llvm-project/commit/f230d91592a17176c626b682880e6d6f49862475.diff

LOG: [mlir][spirv] Turn various passes to plain OperationPass

Made the passes that convert ops from other dialects to SPIR-V plain
OperationPasses (instead of passes anchored on ModuleOp), so that downstream
compilers can nest them in a proper pass manager and lower only device code.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D131591

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
    mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
    mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
    mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
    mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
    mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
    mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
    mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
    mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
    mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
index 0e426082e57c3..dfc94a8cf232f 100644
--- a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H
 #define MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H
 
+#include "mlir/Pass/Pass.h"
 #include <memory>
 
 namespace mlir {
@@ -21,7 +22,7 @@ namespace arith {
 void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                        RewritePatternSet &patterns);
 
-std::unique_ptr<Pass> createConvertArithmeticToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertArithmeticToSPIRVPass();
 } // namespace arith
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
index efce9d0de7552..9259626aef22c 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
 class ModuleOp;
 
 /// Creates a pass to convert ControlFlow ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertControlFlowToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
index 14438d9ad452e..8329a54a4178c 100644
--- a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
 class ModuleOp;
 
 /// Creates a pass to convert Func ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertFuncToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertFuncToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
index 83d12ff8fedcf..0281315d3c0cf 100644
--- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
 class ModuleOp;
 
 /// Creates a pass to convert Math ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertMathToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
index 9f81d3376c5bd..bd449ea264d23 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
@@ -24,7 +24,7 @@ class ModuleOp;
 std::unique_ptr<OperationPass<>> createMapMemRefStorageClassPass();
 
 /// Creates a pass to convert MemRef ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertMemRefToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertMemRefToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 39ca0debc2ec3..bacbaeb511c82 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -114,7 +114,7 @@ def ConvertArithmeticToLLVM : Pass<"convert-arith-to-llvm"> {
 // ArithmeticToSPIRV
 //===----------------------------------------------------------------------===//
 
-def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv", "ModuleOp"> {
+def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv"> {
   let summary = "Convert Arithmetic dialect to SPIR-V dialect";
   let constructor = "mlir::arith::createConvertArithmeticToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
@@ -250,7 +250,7 @@ def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
 // ControlFlowToSPIRV
 //===----------------------------------------------------------------------===//
 
-def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv", "ModuleOp"> {
+def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
   let summary = "Convert ControlFlow dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertControlFlowToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
@@ -311,7 +311,7 @@ def ConvertFuncToLLVM : Pass<"convert-func-to-llvm", "ModuleOp"> {
 // FuncToSPIRV
 //===----------------------------------------------------------------------===//
 
-def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv", "ModuleOp"> {
+def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
   let summary = "Convert Func dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertFuncToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
@@ -505,7 +505,7 @@ def ConvertMathToLLVM : Pass<"convert-math-to-llvm"> {
 // MathToSPIRV
 //===----------------------------------------------------------------------===//
 
-def ConvertMathToSPIRV : Pass<"convert-math-to-spirv", "ModuleOp"> {
+def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
   let summary = "Convert Math dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertMathToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
@@ -548,7 +548,7 @@ def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class"> {
   ];
 }
 
-def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
+def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
   let summary = "Convert MemRef dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertMemRefToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
@@ -668,7 +668,7 @@ def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
 // SCFToSPIRV
 //===----------------------------------------------------------------------===//
 
-def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
+def SCFToSPIRV : Pass<"convert-scf-to-spirv"> {
   let summary = "Convert SCF dialect to SPIR-V dialect.";
   let description = [{
     This pass converts SCF ops into SPIR-V structured control flow ops.
@@ -764,7 +764,7 @@ def ConvertTensorToLinalg : Pass<"convert-tensor-to-linalg", "ModuleOp"> {
 // TensorToSPIRV
 //===----------------------------------------------------------------------===//
 
-def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv", "ModuleOp"> {
+def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv"> {
   let summary = "Convert Tensor dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertTensorToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
@@ -961,7 +961,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
 // VectorToSPIRV
 //===----------------------------------------------------------------------===//
 
-def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
+def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
   let summary = "Convert Vector dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertVectorToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];

diff  --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
index af3dbf1ed3504..4299537981db3 100644
--- a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
 class ModuleOp;
 
 /// Creates a pass to convert SCF ops into SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertSCFToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
index 93229d07e469f..5f9081f0bb3ad 100644
--- a/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
 class ModuleOp;
 
 /// Creates a pass to convert Tensor ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertTensorToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
index ad0971e379ea5..5335221bbd42f 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
 class ModuleOp;
 
 /// Creates a pass to convert Vector Ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertVectorToSPIRVPass();
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 6293e7448a64e..52ab62c85dc56 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -912,8 +912,8 @@ namespace {
 struct ConvertArithmeticToSPIRVPass
     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
   void runOnOperation() override {
-    auto module = getOperation();
-    auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+    Operation *op = getOperation();
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
     auto target = SPIRVConversionTarget::get(targetAttr);
 
     SPIRVTypeConverter::Options options;
@@ -934,12 +934,13 @@ struct ConvertArithmeticToSPIRVPass
     RewritePatternSet patterns(&getContext());
     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
 
-    if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
       signalPassFailure();
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
+std::unique_ptr<OperationPass<>>
+mlir::arith::createConvertArithmeticToSPIRVPass() {
   return std::make_unique<ConvertArithmeticToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 1d8c004d76cd8..6cd237ea3e0e1 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertControlFlowToSPIRVPass
 
 void ConvertControlFlowToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
@@ -41,11 +41,10 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
   RewritePatternSet patterns(context);
   cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertControlFlowToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertControlFlowToSPIRVPass() {
   return std::make_unique<ConvertControlFlowToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index b9e3492d04c16..d2416feea763c 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertFuncToSPIRVPass
 
 void ConvertFuncToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
@@ -42,10 +42,10 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
   populateFuncToSPIRVPatterns(typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertFuncToSPIRVPass() {
   return std::make_unique<ConvertFuncToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index fb5c89244b549..480c903f83ec1 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -59,7 +59,7 @@ void GPUToSPIRVPass::runOnOperation() {
     gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
-  // Map MemRef memory space to SPIR-V sotrage class first if requested.
+  // Map MemRef memory space to SPIR-V storage class first if requested.
   if (mapMemorySpace) {
     std::unique_ptr<ConversionTarget> target =
         spirv::getMemorySpaceToStorageClassTarget(*context);

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
index 0817ba7560d95..6ef71d9f27a19 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertMathToSPIRVPass
 
 void ConvertMathToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
@@ -50,10 +50,10 @@ void ConvertMathToSPIRVPass::runOnOperation() {
   RewritePatternSet patterns(context);
   populateMathToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertMathToSPIRVPass() {
   return std::make_unique<ConvertMathToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index cd0d3429c6b58..44fc17bde62fb 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertMemRefToSPIRVPass
 
 void ConvertMemRefToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
@@ -52,11 +52,10 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
   RewritePatternSet patterns(context);
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertMemRefToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToSPIRVPass() {
   return std::make_unique<ConvertMemRefToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 02a0d80183dd0..1b22fadc4a348 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -31,9 +31,9 @@ struct SCFToSPIRVPass : public SCFToSPIRVBase<SCFToSPIRVPass> {
 
 void SCFToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
@@ -49,10 +49,10 @@ void SCFToSPIRVPass::runOnOperation() {
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertSCFToSPIRVPass() {
   return std::make_unique<SCFToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 616e8a54b5be9..3a8ccc8a7b85a 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -26,9 +26,9 @@ class ConvertTensorToSPIRVPass
     : public ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> {
   void runOnOperation() override {
     MLIRContext *context = &getContext();
-    ModuleOp module = getOperation();
+    Operation *op = getOperation();
 
-    auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
     std::unique_ptr<ConversionTarget> target =
         SPIRVConversionTarget::get(targetAttr);
 
@@ -43,13 +43,12 @@ class ConvertTensorToSPIRVPass
                                   patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
 
-    if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
       return signalPassFailure();
   }
 };
 } // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertTensorToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertTensorToSPIRVPass() {
   return std::make_unique<ConvertTensorToSPIRVPass>();
 }

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 7391defe78589..d3585cad4897a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -30,9 +30,9 @@ struct ConvertVectorToSPIRVPass
 
 void ConvertVectorToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
@@ -52,11 +52,10 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
   RewritePatternSet patterns(context);
   populateVectorToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertVectorToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertVectorToSPIRVPass() {
   return std::make_unique<ConvertVectorToSPIRVPass>();
 }


        


More information about the Mlir-commits mailing list