[Mlir-commits] [mlir] dfaebd3 - [mlir][spirv] NFC: move conversion options out of the type converter
Lei Zhang
llvmlistbot at llvm.org
Fri Sep 9 13:15:46 PDT 2022
Author: Lei Zhang
Date: 2022-09-09T16:15:27-04:00
New Revision: dfaebd3d7b99a1812888014337e02a568dd59169
URL: https://github.com/llvm/llvm-project/commit/dfaebd3d7b99a1812888014337e02a568dd59169
DIFF: https://github.com/llvm/llvm-project/commit/dfaebd3d7b99a1812888014337e02a568dd59169.diff
LOG: [mlir][spirv] NFC: move conversion options out of the type converter
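This is a step toward adding more options that are not directly related to
type conversion. Moving the options struct out of the type converter also lets
us drop the user-provided default constructor that worked around
https://bugs.llvm.org/show_bug.cgi?id=36684.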
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D133596
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 11928c95e21c1..81e701fbd2d19 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -25,6 +25,31 @@ namespace mlir {
// Type Converter
//===----------------------------------------------------------------------===//
+struct SPIRVConversionOptions {
+ /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
+ /// no native support.
+ ///
+ /// Non-32-bit scalar types require special hardware support that may not
+ /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
+ /// types require special capabilities or extensions. This option controls
+ /// whether to use 32-bit types to emulate, if a scalar type of a certain
+ /// bitwidth is not supported in the target environment. This requires the
+ /// runtime to also feed in data with a matched bitwidth and layout for
+ /// interface types. The runtime can do that by inspecting the SPIR-V
+ /// module.
+ ///
+ /// If the original scalar type has less than 32-bit, a multiple of its
+ /// values will be packed into one 32-bit value to be memory efficient.
+ bool emulateNon32BitScalarTypes{true};
+
+ /// Use 64-bit integers to convert index types.
+ bool use64bitIndex{false};
+
+ /// The number of bits to store a boolean value. It is eight bits by
+ /// default.
+ unsigned boolNumBits{8};
+};
+
/// Type conversion from builtin types to SPIR-V types for shader interface.
///
/// For memref types, this converter additionally performs type wrapping to
@@ -32,39 +57,8 @@ namespace mlir {
/// pointers to structs.
class SPIRVTypeConverter : public TypeConverter {
public:
- struct Options {
- /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
- /// no native support.
- ///
- /// Non-32-bit scalar types require special hardware support that may not
- /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
- /// types require special capabilities or extensions. This option controls
- /// whether to use 32-bit types to emulate, if a scalar type of a certain
- /// bitwidth is not supported in the target environment. This requires the
- /// runtime to also feed in data with a matched bitwidth and layout for
- /// interface types. The runtime can do that by inspecting the SPIR-V
- /// module.
- ///
- /// If the original scalar type has less than 32-bit, a multiple of its
- /// values will be packed into one 32-bit value to be memory efficient.
- bool emulateNon32BitScalarTypes{true};
-
- /// Use 64-bit integers to convert index types.
- bool use64bitIndex{false};
-
- /// The number of bits to store a boolean value. It is eight bits by
- /// default.
- unsigned boolNumBits{8};
-
- // Note: we need this instead of inline initializers because of
- // https://bugs.llvm.org/show_bug.cgi?id=36684
- Options()
-
- {}
- };
-
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
- Options options = {});
+ const SPIRVConversionOptions &options = {});
/// Gets the SPIR-V correspondence for the standard index type.
Type getIndexType() const;
@@ -72,11 +66,11 @@ class SPIRVTypeConverter : public TypeConverter {
const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
/// Returns the options controlling the SPIR-V type converter.
- const Options &getOptions() const { return options; }
+ const SPIRVConversionOptions &getOptions() const { return options; }
private:
spirv::TargetEnv targetEnv;
- Options options;
+ SPIRVConversionOptions options;
MLIRContext *getContext() const;
};
diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index b65d13c13b943..2af5fc9318fd2 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -959,7 +959,7 @@ struct ConvertArithmeticToSPIRVPass
auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
auto target = SPIRVConversionTarget::get(targetAttr);
- SPIRVTypeConverter::Options options;
+ SPIRVConversionOptions options;
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 69b2a874d7b06..0d1e8b8079465 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -40,7 +40,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
- SPIRVTypeConverter::Options options;
+ SPIRVConversionOptions options;
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 303692ad8bcbc..a82ba5dd12a5d 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -39,7 +39,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
- SPIRVTypeConverter::Options options;
+ SPIRVConversionOptions options;
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index 13b2b45d34c6f..b12311bc5624c 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -39,7 +39,7 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
- SPIRVTypeConverter::Options options;
+ SPIRVConversionOptions options;
options.boolNumBits = this->boolNumBits;
SPIRVTypeConverter typeConverter(targetAttr, options);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 4c0dfb62939f0..b14ce1c36b3a7 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -37,7 +37,7 @@ class ConvertTensorToSPIRVPass
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
- SPIRVTypeConverter::Options options;
+ SPIRVConversionOptions options;
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index fc2f8c98f8821..8d0bde66ebdf9 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -124,8 +124,8 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
-static Optional<int64_t>
-getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
+static Optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
+ Type type) {
if (type.isa<spirv::ScalarType>()) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
@@ -199,7 +199,7 @@ getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Type convertScalarType(const spirv::TargetEnv &targetEnv,
- const SPIRVTypeConverter::Options &options,
+ const SPIRVConversionOptions &options,
spirv::ScalarType type,
Optional<spirv::StorageClass> storageClass = {}) {
// Get extension and capability requirements for the given type.
@@ -232,7 +232,7 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
static Type convertVectorType(const spirv::TargetEnv &targetEnv,
- const SPIRVTypeConverter::Options &options,
+ const SPIRVConversionOptions &options,
VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
if (type.getRank() <= 1 && type.getNumElements() == 1)
@@ -271,7 +271,7 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
/// constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate, like what we do for vectors.
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
- const SPIRVTypeConverter::Options &options,
+ const SPIRVConversionOptions &options,
TensorType type) {
// TODO: Handle dynamic shapes.
if (!type.hasStaticShape()) {
@@ -310,7 +310,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
- const SPIRVTypeConverter::Options &options,
+ const SPIRVConversionOptions &options,
MemRefType type,
spirv::StorageClass storageClass) {
unsigned numBoolBits = options.boolNumBits;
@@ -349,7 +349,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
}
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
- const SPIRVTypeConverter::Options &options,
+ const SPIRVConversionOptions &options,
MemRefType type) {
auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
if (!attr) {
@@ -414,7 +414,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
- Options options)
+ const SPIRVConversionOptions &options)
: targetEnv(targetAttr), options(options) {
// Add conversions. The order matters here: later ones will be tried earlier.
More information about the Mlir-commits
mailing list