[Mlir-commits] [mlir] dfaebd3 - [mlir][spirv] NFC: move conversion options out of the type converter

Lei Zhang llvmlistbot at llvm.org
Fri Sep 9 13:15:46 PDT 2022


Author: Lei Zhang
Date: 2022-09-09T16:15:27-04:00
New Revision: dfaebd3d7b99a1812888014337e02a568dd59169

URL: https://github.com/llvm/llvm-project/commit/dfaebd3d7b99a1812888014337e02a568dd59169
DIFF: https://github.com/llvm/llvm-project/commit/dfaebd3d7b99a1812888014337e02a568dd59169.diff

LOG: [mlir][spirv] NFC: move conversion options out of the type converter

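This is a step toward adding more options that are not directly related to
type conversion. Moving the options struct out of SPIRVTypeConverter also lets
us drop its explicit `Options()` constructor, which was needed only as a
workaround for https://bugs.llvm.org/show_bug.cgi?id=36684 while the struct was
nested inside the type converter.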

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D133596

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
    mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
    mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
    mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 11928c95e21c1..81e701fbd2d19 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -25,6 +25,31 @@ namespace mlir {
 // Type Converter
 //===----------------------------------------------------------------------===//
 
+struct SPIRVConversionOptions {
+  /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
+  /// no native support.
+  ///
+  /// Non-32-bit scalar types require special hardware support that may not
+  /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
+  /// types require special capabilities or extensions. This option controls
+  /// whether to use 32-bit types to emulate, if a scalar type of a certain
+  /// bitwidth is not supported in the target environment. This requires the
+  /// runtime to also feed in data with a matched bitwidth and layout for
+  /// interface types. The runtime can do that by inspecting the SPIR-V
+  /// module.
+  ///
+  /// If the original scalar type has less than 32-bit, a multiple of its
+  /// values will be packed into one 32-bit value to be memory efficient.
+  bool emulateNon32BitScalarTypes{true};
+
+  /// Use 64-bit integers to convert index types.
+  bool use64bitIndex{false};
+
+  /// The number of bits to store a boolean value. It is eight bits by
+  /// default.
+  unsigned boolNumBits{8};
+};
+
 /// Type conversion from builtin types to SPIR-V types for shader interface.
 ///
 /// For memref types, this converter additionally performs type wrapping to
@@ -32,39 +57,8 @@ namespace mlir {
 /// pointers to structs.
 class SPIRVTypeConverter : public TypeConverter {
 public:
-  struct Options {
-    /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
-    /// no native support.
-    ///
-    /// Non-32-bit scalar types require special hardware support that may not
-    /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
-    /// types require special capabilities or extensions. This option controls
-    /// whether to use 32-bit types to emulate, if a scalar type of a certain
-    /// bitwidth is not supported in the target environment. This requires the
-    /// runtime to also feed in data with a matched bitwidth and layout for
-    /// interface types. The runtime can do that by inspecting the SPIR-V
-    /// module.
-    ///
-    /// If the original scalar type has less than 32-bit, a multiple of its
-    /// values will be packed into one 32-bit value to be memory efficient.
-    bool emulateNon32BitScalarTypes{true};
-
-    /// Use 64-bit integers to convert index types.
-    bool use64bitIndex{false};
-
-    /// The number of bits to store a boolean value. It is eight bits by
-    /// default.
-    unsigned boolNumBits{8};
-
-    // Note: we need this instead of inline initializers because of
-    // https://bugs.llvm.org/show_bug.cgi?id=36684
-    Options()
-
-    {}
-  };
-
   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
-                              Options options = {});
+                              const SPIRVConversionOptions &options = {});
 
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
@@ -72,11 +66,11 @@ class SPIRVTypeConverter : public TypeConverter {
   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
 
   /// Returns the options controlling the SPIR-V type converter.
-  const Options &getOptions() const { return options; }
+  const SPIRVConversionOptions &getOptions() const { return options; }
 
 private:
   spirv::TargetEnv targetEnv;
-  Options options;
+  SPIRVConversionOptions options;
 
   MLIRContext *getContext() const;
 };

diff  --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
index b65d13c13b943..2af5fc9318fd2 100644
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -959,7 +959,7 @@ struct ConvertArithmeticToSPIRVPass
     auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
     auto target = SPIRVConversionTarget::get(targetAttr);
 
-    SPIRVTypeConverter::Options options;
+    SPIRVConversionOptions options;
     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 

diff  --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 69b2a874d7b06..0d1e8b8079465 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -40,7 +40,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
-  SPIRVTypeConverter::Options options;
+  SPIRVConversionOptions options;
   options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 

diff  --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 303692ad8bcbc..a82ba5dd12a5d 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -39,7 +39,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
-  SPIRVTypeConverter::Options options;
+  SPIRVConversionOptions options;
   options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
index 13b2b45d34c6f..b12311bc5624c 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp
@@ -39,7 +39,7 @@ void ConvertMemRefToSPIRVPass::runOnOperation() {
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
 
-  SPIRVTypeConverter::Options options;
+  SPIRVConversionOptions options;
   options.boolNumBits = this->boolNumBits;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 

diff  --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 4c0dfb62939f0..b14ce1c36b3a7 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -37,7 +37,7 @@ class ConvertTensorToSPIRVPass
     std::unique_ptr<ConversionTarget> target =
         SPIRVConversionTarget::get(targetAttr);
 
-    SPIRVTypeConverter::Options options;
+    SPIRVConversionOptions options;
     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index fc2f8c98f8821..8d0bde66ebdf9 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -124,8 +124,8 @@ MLIRContext *SPIRVTypeConverter::getContext() const {
 
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
-static Optional<int64_t>
-getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
+static Optional<int64_t> getTypeNumBytes(const SPIRVConversionOptions &options,
+                                         Type type) {
   if (type.isa<spirv::ScalarType>()) {
     auto bitWidth = type.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
@@ -199,7 +199,7 @@ getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
 
 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVTypeConverter::Options &options,
+                              const SPIRVConversionOptions &options,
                               spirv::ScalarType type,
                               Optional<spirv::StorageClass> storageClass = {}) {
   // Get extension and capability requirements for the given type.
@@ -232,7 +232,7 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
 
 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVTypeConverter::Options &options,
+                              const SPIRVConversionOptions &options,
                               VectorType type,
                               Optional<spirv::StorageClass> storageClass = {}) {
   if (type.getRank() <= 1 && type.getNumElements() == 1)
@@ -271,7 +271,7 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
 /// constant values and use OpCompositeExtract and OpCompositeInsert to
 /// manipulate, like what we do for vectors.
 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVTypeConverter::Options &options,
+                              const SPIRVConversionOptions &options,
                               TensorType type) {
   // TODO: Handle dynamic shapes.
   if (!type.hasStaticShape()) {
@@ -310,7 +310,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
 }
 
 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
-                                  const SPIRVTypeConverter::Options &options,
+                                  const SPIRVConversionOptions &options,
                                   MemRefType type,
                                   spirv::StorageClass storageClass) {
   unsigned numBoolBits = options.boolNumBits;
@@ -349,7 +349,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
 }
 
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
-                              const SPIRVTypeConverter::Options &options,
+                              const SPIRVConversionOptions &options,
                               MemRefType type) {
   auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
   if (!attr) {
@@ -414,7 +414,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
 }
 
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
-                                       Options options)
+                                       const SPIRVConversionOptions &options)
     : targetEnv(targetAttr), options(options) {
   // Add conversions. The order matters here: later ones will be tried earlier.
 


        


More information about the Mlir-commits mailing list