[Mlir-commits] [mlir] f230d91 - [mlir][spirv] Turn various passes to plain OperationPass
Lei Zhang
llvmlistbot at llvm.org
Wed Aug 10 10:50:15 PDT 2022
Author: jackalcooper
Date: 2022-08-10T13:50:07-04:00
New Revision: f230d91592a17176c626b682880e6d6f49862475
URL: https://github.com/llvm/llvm-project/commit/f230d91592a17176c626b682880e6d6f49862475
DIFF: https://github.com/llvm/llvm-project/commit/f230d91592a17176c626b682880e6d6f49862475.diff
LOG: [mlir][spirv] Turn various passes to plain OperationPass
Made the passes that convert ops from other dialects to SPIR-V plain
OperationPasses (no longer anchored on ModuleOp), so that downstream
compilers can schedule them in a properly nested pass manager and lower
only device code.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D131591
Added:
Modified:
mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
index 0e426082e57c3..dfc94a8cf232f 100644
--- a/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
+++ b/mlir/include/mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H
#define MLIR_CONVERSION_ARITHMETICTOSPIRV_ARITHMETICTOSPIRV_H
+#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
@@ -21,7 +22,7 @@ namespace arith {
void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
-std::unique_ptr<Pass> createConvertArithmeticToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertArithmeticToSPIRVPass();
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
index efce9d0de7552..9259626aef22c 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
class ModuleOp;
/// Creates a pass to convert ControlFlow ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertControlFlowToSPIRVPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
index 14438d9ad452e..8329a54a4178c 100644
--- a/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
class ModuleOp;
/// Creates a pass to convert Func ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertFuncToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertFuncToSPIRVPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
index 83d12ff8fedcf..0281315d3c0cf 100644
--- a/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
class ModuleOp;
/// Creates a pass to convert Math ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertMathToSPIRVPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
index 9f81d3376c5bd..bd449ea264d23 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
@@ -24,7 +24,7 @@ class ModuleOp;
std::unique_ptr<OperationPass<>> createMapMemRefStorageClassPass();
/// Creates a pass to convert MemRef ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertMemRefToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertMemRefToSPIRVPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 39ca0debc2ec3..bacbaeb511c82 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -114,7 +114,7 @@ def ConvertArithmeticToLLVM : Pass<"convert-arith-to-llvm"> {
// ArithmeticToSPIRV
//===----------------------------------------------------------------------===//
-def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv", "ModuleOp"> {
+def ConvertArithmeticToSPIRV : Pass<"convert-arith-to-spirv"> {
let summary = "Convert Arithmetic dialect to SPIR-V dialect";
let constructor = "mlir::arith::createConvertArithmeticToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
@@ -250,7 +250,7 @@ def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
// ControlFlowToSPIRV
//===----------------------------------------------------------------------===//
-def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv", "ModuleOp"> {
+def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
let summary = "Convert ControlFlow dialect to SPIR-V dialect";
let constructor = "mlir::createConvertControlFlowToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
@@ -311,7 +311,7 @@ def ConvertFuncToLLVM : Pass<"convert-func-to-llvm", "ModuleOp"> {
// FuncToSPIRV
//===----------------------------------------------------------------------===//
-def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv", "ModuleOp"> {
+def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
let summary = "Convert Func dialect to SPIR-V dialect";
let constructor = "mlir::createConvertFuncToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
@@ -505,7 +505,7 @@ def ConvertMathToLLVM : Pass<"convert-math-to-llvm"> {
// MathToSPIRV
//===----------------------------------------------------------------------===//
-def ConvertMathToSPIRV : Pass<"convert-math-to-spirv", "ModuleOp"> {
+def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
let summary = "Convert Math dialect to SPIR-V dialect";
let constructor = "mlir::createConvertMathToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
@@ -548,7 +548,7 @@ def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class"> {
];
}
-def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> {
+def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
let summary = "Convert MemRef dialect to SPIR-V dialect";
let constructor = "mlir::createConvertMemRefToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
@@ -668,7 +668,7 @@ def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
// SCFToSPIRV
//===----------------------------------------------------------------------===//
-def SCFToSPIRV : Pass<"convert-scf-to-spirv", "ModuleOp"> {
+def SCFToSPIRV : Pass<"convert-scf-to-spirv"> {
let summary = "Convert SCF dialect to SPIR-V dialect.";
let description = [{
This pass converts SCF ops into SPIR-V structured control flow ops.
@@ -764,7 +764,7 @@ def ConvertTensorToLinalg : Pass<"convert-tensor-to-linalg", "ModuleOp"> {
// TensorToSPIRV
//===----------------------------------------------------------------------===//
-def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv", "ModuleOp"> {
+def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv"> {
let summary = "Convert Tensor dialect to SPIR-V dialect";
let constructor = "mlir::createConvertTensorToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
@@ -961,7 +961,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
// VectorToSPIRV
//===----------------------------------------------------------------------===//
-def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv", "ModuleOp"> {
+def ConvertVectorToSPIRV : Pass<"convert-vector-to-spirv"> {
let summary = "Convert Vector dialect to SPIR-V dialect";
let constructor = "mlir::createConvertVectorToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
index af3dbf1ed3504..4299537981db3 100644
--- a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
class ModuleOp;
/// Creates a pass to convert SCF ops into SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertSCFToSPIRVPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
index 93229d07e469f..5f9081f0bb3ad 100644
--- a/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
class ModuleOp;
/// Creates a pass to convert Tensor ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertTensorToSPIRVPass();
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
index ad0971e379ea5..5335221bbd42f 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h
@@ -19,7 +19,7 @@ namespace mlir {
class ModuleOp;
/// Creates a pass to convert Vector Ops to SPIR-V ops.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToSPIRVPass();
+std::unique_ptr<OperationPass<>> createConvertVectorToSPIRVPass();
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index 6293e7448a64e..52ab62c85dc56 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -912,8 +912,8 @@ namespace {
struct ConvertArithmeticToSPIRVPass
: public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
void runOnOperation() override {
- auto module = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ Operation *op = getOperation();
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
auto target = SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter::Options options;
@@ -934,12 +934,13 @@ struct ConvertArithmeticToSPIRVPass
RewritePatternSet patterns(&getContext());
arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
-std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
+std::unique_ptr<OperationPass<>>
+mlir::arith::createConvertArithmeticToSPIRVPass() {
return std::make_unique<ConvertArithmeticToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 1d8c004d76cd8..6cd237ea3e0e1 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertControlFlowToSPIRVPass
void ConvertControlFlowToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -41,11 +41,10 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
RewritePatternSet patterns(context);
cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertControlFlowToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertControlFlowToSPIRVPass() {
return std::make_unique<ConvertControlFlowToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index b9e3492d04c16..d2416feea763c 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertFuncToSPIRVPass
void ConvertFuncToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -42,10 +42,10 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
populateFuncToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertFuncToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertFuncToSPIRVPass() {
return std::make_unique<ConvertFuncToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index fb5c89244b549..480c903f83ec1 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -59,7 +59,7 @@ void GPUToSPIRVPass::runOnOperation() {
gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
});
- // Map MemRef memory space to SPIR-V sotrage class first if requested.
+ // Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
index 0817ba7560d95..6ef71d9f27a19 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertMathToSPIRVPass
void ConvertMathToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -50,10 +50,10 @@ void ConvertMathToSPIRVPass::runOnOperation() {
RewritePatternSet patterns(context);
populateMathToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertMathToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertMathToSPIRVPass() {
return std::make_unique<ConvertMathToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index cd0d3429c6b58..44fc17bde62fb 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -28,9 +28,9 @@ class ConvertMemRefToSPIRVPass
void ConvertMemRefToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -52,11 +52,10 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
RewritePatternSet patterns(context);
populateMemRefToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertMemRefToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToSPIRVPass() {
return std::make_unique<ConvertMemRefToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index 02a0d80183dd0..1b22fadc4a348 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -31,9 +31,9 @@ struct SCFToSPIRVPass : public SCFToSPIRVBase<SCFToSPIRVPass> {
void SCFToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -49,10 +49,10 @@ void SCFToSPIRVPass::runOnOperation() {
populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertSCFToSPIRVPass() {
return std::make_unique<SCFToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 616e8a54b5be9..3a8ccc8a7b85a 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -26,9 +26,9 @@ class ConvertTensorToSPIRVPass
: public ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -43,13 +43,12 @@ class ConvertTensorToSPIRVPass
patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertTensorToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertTensorToSPIRVPass() {
return std::make_unique<ConvertTensorToSPIRVPass>();
}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 7391defe78589..d3585cad4897a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -30,9 +30,9 @@ struct ConvertVectorToSPIRVPass
void ConvertVectorToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+ auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
@@ -52,11 +52,10 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
RewritePatternSet patterns(context);
populateVectorToSPIRVPatterns(typeConverter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createConvertVectorToSPIRVPass() {
+std::unique_ptr<OperationPass<>> mlir::createConvertVectorToSPIRVPass() {
return std::make_unique<ConvertVectorToSPIRVPass>();
}
More information about the Mlir-commits
mailing list