[Mlir-commits] [mlir] c064545 - [mlir][spirv] Do not truncate i/f64 -> i/f32 in SPIRVConversion

Jakub Kuderski llvmlistbot at llvm.org
Fri Nov 4 12:11:43 PDT 2022

Author: Jakub Kuderski
Date: 2022-11-04T15:10:28-04:00
New Revision: c064545403917bedd450e07209e7870f1773f90f

URL: https://github.com/llvm/llvm-project/commit/c064545403917bedd450e07209e7870f1773f90f
DIFF: https://github.com/llvm/llvm-project/commit/c064545403917bedd450e07209e7870f1773f90f.diff

LOG: [mlir][spirv] Do not truncate i/f64 -> i/f32 in SPIRVConversion

This truncation can be unexpected and break program behavior.
Dedicated emulation passes should be used instead.

Also rename pass options to "emulate-lt-32-bit-scalar-types".

Fixes: https://github.com/llvm/llvm-project/issues/57917

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D137115




diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 66ac9eedf1bfb..cef82f1e29ff1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -118,10 +118,10 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
   let constructor = "mlir::arith::createConvertArithToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
   let options = [
-    Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+    Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
-           "Emulate non-32-bit scalar types with 32-bit ones if "
-           "missing native support">,
+           "Emulate narrower scalar types with 32-bit ones if not supported by "
+           "the target">,
     Option<"enableFastMath", "enable-fast-math",
            "bool", /*default=*/"false",
            "Enable fast math mode (assuming no NaN and infinity for floating "
@@ -259,10 +259,10 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
   let constructor = "mlir::createConvertControlFlowToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
   let options = [
-    Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+    Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
-           "Emulate non-32-bit scalar types with 32-bit ones if "
-           "missing native support">
+           "Emulate narrower scalar types with 32-bit ones if not supported by"
+           " the target">
@@ -320,10 +320,10 @@ def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
   let constructor = "mlir::createConvertFuncToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
   let options = [
-    Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+    Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
-           "Emulate non-32-bit scalar types with 32-bit ones if "
-           "missing native support">
+           "Emulate narrower scalar types with 32-bit ones if not supported by"
+           " the target">
@@ -815,10 +815,10 @@ def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv"> {
   let constructor = "mlir::createConvertTensorToSPIRVPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];
   let options = [
-    Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+    Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
-           "Emulate non-32-bit scalar types with 32-bit ones if "
-           "missing native support">
+           "Emulate narrower scalar types with 32-bit ones if not supported by"
+           " the target">

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9b480f6cc9e3a..7d362526cc22f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -30,21 +30,21 @@ struct SPIRVConversionOptions {
   /// The number of bits to store a boolean value.
   unsigned boolNumBits{8};
-  /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
-  /// no native support.
+  /// Whether to emulate narrower scalar types with 32-bit scalar types if not
+  /// supported by the target.
   /// Non-32-bit scalar types require special hardware support that may not
   /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
   /// types require special capabilities or extensions. This option controls
-  /// whether to use 32-bit types to emulate, if a scalar type of a certain
-  /// bitwidth is not supported in the target environment. This requires the
-  /// runtime to also feed in data with a matched bitwidth and layout for
-  /// interface types. The runtime can do that by inspecting the SPIR-V
-  /// module.
+  /// whether to use 32-bit types to emulate < 32-bits-wide scalars, if a scalar
+  /// type of a certain bitwidth is not supported in the target environment.
+  /// This requires the runtime to also feed in data with a matched bitwidth and
+  /// layout for interface types. The runtime can do that by inspecting the
+  /// SPIR-V module.
   /// If the original scalar type has less than 32-bit, a multiple of its
   /// values will be packed into one 32-bit value to be memory efficient.
-  bool emulateNon32BitScalarTypes{true};
+  bool emulateLT32BitScalarTypes{true};
   /// Use 64-bit integers to convert index types.
   bool use64bitIndex{false};

diff  --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 2452928dd4503..cf65beb924fb7 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -1031,7 +1031,7 @@ struct ConvertArithToSPIRVPass
     auto target = SPIRVConversionTarget::get(targetAttr);
     SPIRVConversionOptions options;
-    options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+    options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
     options.enableFastMathMode = this->enableFastMath;
     SPIRVTypeConverter typeConverter(targetAttr, options);

diff  --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 0d1e8b8079465..d8aecae257b46 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -41,7 +41,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
   SPIRVConversionOptions options;
-  options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index a82ba5dd12a5d..9fffc5e3182e9 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -40,7 +40,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
   SPIRVConversionOptions options;
-  options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
   RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 6b1145c464787..313172614268d 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -38,7 +38,7 @@ class ConvertTensorToSPIRVPass
     SPIRVConversionOptions options;
-    options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+    options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
     RewritePatternSet patterns(context);

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2514cfe0301a1..286ff0b7eff2d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -220,9 +220,16 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
   // Otherwise we need to adjust the type, which really means adjusting the
   // bitwidth given this is a scalar type.
+  if (!options.emulateLT32BitScalarTypes)
+    return nullptr;
-  if (!options.emulateNon32BitScalarTypes)
+  // We only emulate narrower scalar types here and do not truncate results.
+  if (type.getIntOrFloatBitWidth() > 32) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type
+               << " not converted to 32-bit for SPIR-V to avoid truncation\n");
     return nullptr;
+  }
   if (auto floatType = type.dyn_cast<FloatType>()) {
     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 967adbc84a3bb..f6e84e80bbf51 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -49,7 +49,15 @@ func.func @int_vector4_invalid(%arg0: vector<2xi16>) {
 // -----
-func.func @unsupported_constant_0() {
+func.func @unsupported_constant_i64_0() {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+  %0 = arith.constant 0 : i64
+  return
+// -----
+func.func @unsupported_constant_i64_1() {
   // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
   %0 = arith.constant 4294967296 : i64 // 2^32
@@ -57,16 +65,68 @@ func.func @unsupported_constant_0() {
 // -----
-func.func @unsupported_constant_1() {
+func.func @unsupported_constant_vector_2xi64_0() {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+  %1 = arith.constant dense<0> : vector<2xi64>
+  return
+// -----
+func.func @unsupported_constant_f64_0() {
   // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
-  %1 = arith.constant -2147483649 : i64 // -2^31 - 1
+  %1 = arith.constant 0.0 : f64
 // -----
-func.func @unsupported_constant_2() {
+func.func @unsupported_constant_vector_2xf64_0() {
   // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
-  %2 = arith.constant -2147483649 : i64 // -2^31 - 1
+  %1 = arith.constant dense<0.0> : vector<2xf64>
+// -----
+func.func @unsupported_constant_tensor_2xf64_0() {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+  %1 = arith.constant dense<0.0> : tensor<2xf64>
+  return
+// Type emulation
+// -----
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+// Check that we do not emulate i64 by truncating to i32.
+func.func @unsupported_i64(%arg0: i64) {
+  // expected-error@+1 {{failed to legalize operation 'arith.addi'}}
+  %2 = arith.addi %arg0, %arg0: i64
+  return
+} // end module
+// -----
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+// Check that we do not emulate f64 by truncating to f32.
+func.func @unsupported_f64(%arg0: f64) {
+  // expected-error@+1 {{failed to legalize operation 'arith.addf'}}
+  %2 = arith.addf %arg0, %arg0: f64
+  return
+} // end module

diff  --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index df6806a0e4bd1..d561cd2c26f29 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -513,7 +513,7 @@ func.func @constant_size1() {
 // -----
-// Check that constants are converted to 32-bit when no special capability.
+// Check that constants are widened to 32-bit when no special capability.
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
@@ -533,51 +533,26 @@ func.func @constant_16bit() {
-// CHECK-LABEL: @constant_64bit
-func.func @constant_64bit() {
-  // CHECK: spirv.Constant 4 : i32
-  %0 = arith.constant 4 : i64
-  // CHECK: spirv.Constant 5.000000e+00 : f32
-  %1 = arith.constant 5.0 : f64
-  // CHECK: spirv.Constant dense<[2, 3]> : vector<2xi32>
-  %2 = arith.constant dense<[2, 3]> : vector<2xi64>
-  // CHECK: spirv.Constant dense<4.000000e+00> : tensor<5xf32> : !spirv.array<5 x f32>
-  %3 = arith.constant dense<4.0> : tensor<5xf64>
-  // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spirv.array<4 x f32>
-  %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
-  return
 // CHECK-LABEL: @constant_size1
 func.func @constant_size1() {
   // CHECK: spirv.Constant 4 : i32
-  %0 = arith.constant dense<4> : vector<1xi64>
+  %0 = arith.constant dense<4> : vector<1xi16>
   // CHECK: spirv.Constant 5.000000e+00 : f32
-  %1 = arith.constant dense<5.0> : tensor<1xf64>
+  %1 = arith.constant dense<5.0> : tensor<1xf16>
 // CHECK-LABEL: @corner_cases
 func.func @corner_cases() {
-  // CHECK: %{{.*}} = spirv.Constant -1 : i32
-  %0 = arith.constant 4294967295  : i64 // 2^32 - 1
-  // CHECK: %{{.*}} = spirv.Constant 2147483647 : i32
-  %1 = arith.constant 2147483647  : i64 // 2^31 - 1
-  // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
-  %2 = arith.constant 2147483648  : i64 // 2^31
-  // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
-  %3 = arith.constant -2147483648 : i64 // -2^31
-  // CHECK: %{{.*}} = spirv.Constant -1 : i32
-  %5 = arith.constant -1 : i64
+ // CHECK: %{{.*}} = spirv.Constant -1 : i32
+  %5 = arith.constant -1 : i16
   // CHECK: %{{.*}} = spirv.Constant -2 : i32
-  %6 = arith.constant -2 : i64
+  %6 = arith.constant -2 : i16
   // CHECK: %{{.*}} = spirv.Constant -1 : i32
   %7 = arith.constant -1 : index
   // CHECK: %{{.*}} = spirv.Constant -2 : i32
   %8 = arith.constant -2 : index
   // CHECK: spirv.Constant false
   %9 = arith.constant false
   // CHECK: spirv.Constant true
@@ -903,29 +878,13 @@ module attributes {
 } {
 // CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[A:.*]]: f64
-func.func @fptrunc1(%arg0 : f64) -> f16 {
-  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
-  // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
-  %0 = arith.truncf %arg0 : f64 to f16
-  return %0: f16
-// CHECK-LABEL: @fptrunc2
 // CHECK-SAME: %[[ARG:.*]]: f32
-func.func @fptrunc2(%arg0: f32) -> f16 {
+func.func @fptrunc1(%arg0: f32) -> f16 {
   // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
   %0 = arith.truncf %arg0 : f32 to f16
   return %0: f16
-// CHECK-LABEL: @sitofp
-func.func @sitofp(%arg0 : i64) -> f64 {
-  // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
-  %0 = arith.sitofp %arg0 : i64 to f64
-  return %0: f64
 } // end module
 // -----
@@ -1209,11 +1168,9 @@ func.func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
 // CHECK-LABEL: @float_scalar
-func.func @float_scalar(%arg0: f16, %arg1: f64) {
+func.func @float_scalar(%arg0: f16) {
   // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: f32
   %0 = arith.addf %arg0, %arg0: f16
-  // CHECK: spirv.FMul %{{.*}}, %{{.*}}: f32
-  %1 = arith.mulf %arg1, %arg1: f64
@@ -1513,74 +1470,6 @@ func.func @constant_64bit() {
 // -----
-// Check that constants are converted to 32-bit when no special capability.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-// CHECK-LABEL: @constant_16bit
-func.func @constant_16bit() {
-  // CHECK: spirv.Constant 4 : i32
-  %0 = arith.constant 4 : i16
-  // CHECK: spirv.Constant 5.000000e+00 : f32
-  %1 = arith.constant 5.0 : f16
-  // CHECK: spirv.Constant dense<[2, 3]> : vector<2xi32>
-  %2 = arith.constant dense<[2, 3]> : vector<2xi16>
-  // CHECK: spirv.Constant dense<4.000000e+00> : tensor<5xf32> : !spirv.array<5 x f32>
-  %3 = arith.constant dense<4.0> : tensor<5xf16>
-  // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spirv.array<4 x f32>
-  %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
-  return
-// CHECK-LABEL: @constant_64bit
-func.func @constant_64bit() {
-  // CHECK: spirv.Constant 4 : i32
-  %0 = arith.constant 4 : i64
-  // CHECK: spirv.Constant 5.000000e+00 : f32
-  %1 = arith.constant 5.0 : f64
-  // CHECK: spirv.Constant dense<[2, 3]> : vector<2xi32>
-  %2 = arith.constant dense<[2, 3]> : vector<2xi64>
-  // CHECK: spirv.Constant dense<4.000000e+00> : tensor<5xf32> : !spirv.array<5 x f32>
-  %3 = arith.constant dense<4.0> : tensor<5xf64>
-  // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spirv.array<4 x f32>
-  %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
-  return
-// CHECK-LABEL: @corner_cases
-func.func @corner_cases() {
-  // CHECK: %{{.*}} = spirv.Constant -1 : i32
-  %0 = arith.constant 4294967295  : i64 // 2^32 - 1
-  // CHECK: %{{.*}} = spirv.Constant 2147483647 : i32
-  %1 = arith.constant 2147483647  : i64 // 2^31 - 1
-  // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
-  %2 = arith.constant 2147483648  : i64 // 2^31
-  // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
-  %3 = arith.constant -2147483648 : i64 // -2^31
-  // CHECK: %{{.*}} = spirv.Constant -1 : i32
-  %5 = arith.constant -1 : i64
-  // CHECK: %{{.*}} = spirv.Constant -2 : i32
-  %6 = arith.constant -2 : i64
-  // CHECK: %{{.*}} = spirv.Constant -1 : i32
-  %7 = arith.constant -1 : index
-  // CHECK: %{{.*}} = spirv.Constant -2 : i32
-  %8 = arith.constant -2 : index
-  // CHECK: spirv.Constant false
-  %9 = arith.constant false
-  // CHECK: spirv.Constant true
-  %10 = arith.constant true
-  return
-} // end module
-// -----
 // std cast ops
@@ -1847,39 +1736,3 @@ func.func @fpext2(%arg0 : f32) -> f64 {
 } // end module
-// -----
-// Checks that cast types will be adjusted when missing special capabilities for
-// certain non-32-bit scalar types.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16], []>, #spirv.resource_limits<>>
-} {
-// CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[A:.*]]: f64
-func.func @fptrunc1(%arg0 : f64) -> f16 {
-  // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
-  // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
-  %0 = arith.truncf %arg0 : f64 to f16
-  return %0: f16
-// CHECK-LABEL: @fptrunc2
-// CHECK-SAME: %[[ARG:.*]]: f32
-func.func @fptrunc2(%arg0: f32) -> f16 {
-  // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
-  %0 = arith.truncf %arg0 : f32 to f16
-  return %0: f16
-// CHECK-LABEL: @sitofp
-func.func @sitofp(%arg0 : i64) -> f64 {
-  // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
-  %0 = arith.sitofp %arg0 : i64 to f64
-  return %0: f64
-} // end module
-// -----

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 4f1cd09efec30..d207ecd71c3cb 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
-// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-non-32-bit-scalar-types=false" %s -o - | FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
+// RUN:   FileCheck %s --check-prefix=NOEMU
 // Integer types
@@ -15,7 +16,7 @@ module attributes {
 // CHECK-SAME: i32
 // CHECK-SAME: si32
 // CHECK-SAME: ui32
-// NOEMU-LABEL: func @integer8
+// NOEMU-LABEL: func.func @integer8
 // NOEMU-SAME: i8
 // NOEMU-SAME: si8
 // NOEMU-SAME: ui8
@@ -25,16 +26,17 @@ func.func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return }
 // CHECK-SAME: i32
 // CHECK-SAME: si32
 // CHECK-SAME: ui32
-// NOEMU-LABEL: func @integer16
+// NOEMU-LABEL: func.func @integer16
 // NOEMU-SAME: i16
 // NOEMU-SAME: si16
 // NOEMU-SAME: ui16
 func.func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }
-// CHECK-LABEL: spirv.func @integer64
-// CHECK-SAME: i32
-// CHECK-SAME: si32
-// CHECK-SAME: ui32
+// We do not truncate 64-bit types to 32-bit ones.
+// CHECK-LABEL: func.func @integer64
+// CHECK-SAME: i64
+// CHECK-SAME: si64
+// CHECK-SAME: ui64
 // NOEMU-LABEL: func @integer64
 // NOEMU-SAME: i64
 // NOEMU-SAME: si64
@@ -131,13 +133,13 @@ module attributes {
 // CHECK-LABEL: spirv.func @float16
 // CHECK-SAME: f32
-// NOEMU-LABEL: func @float16
+// NOEMU-LABEL: func.func @float16
 // NOEMU-SAME: f16
 func.func @float16(%arg0: f16) { return }
-// CHECK-LABEL: spirv.func @float64
-// CHECK-SAME: f32
-// NOEMU-LABEL: func @float64
+// CHECK-LABEL: func.func @float64
+// CHECK-SAME: f64
+// NOEMU-LABEL: func.func @float64
 // NOEMU-SAME: f64
 func.func @float64(%arg0: f64) { return }
@@ -184,7 +186,7 @@ func.func @bf16_type(%arg0: bf16) { return }
 // Check that capabilities for scalar types affects vector types too: no special
-// capabilities available means using turning element types to 32-bit.
+// capabilities available means widening element types to 32-bit.
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
@@ -192,19 +194,15 @@ module attributes {
 // CHECK-LABEL: spirv.func @int_vector
 // CHECK-SAME: vector<2xi32>
 // CHECK-SAME: vector<3xsi32>
-// CHECK-SAME: vector<4xui32>
 func.func @int_vector(
   %arg0: vector<2xi8>,
-  %arg1: vector<3xsi16>,
-  %arg2: vector<4xui64>
+  %arg1: vector<3xsi16>
 ) { return }
 // CHECK-LABEL: spirv.func @float_vector
 // CHECK-SAME: vector<2xf32>
-// CHECK-SAME: vector<3xf32>
 func.func @float_vector(
-  %arg0: vector<2xf16>,
-  %arg1: vector<3xf64>
+  %arg0: vector<2xf16>
 ) { return }
 // CHECK-LABEL: spirv.func @one_element_vector
@@ -389,33 +387,35 @@ func.func @memref_16bit_Input(%arg3: memref<16xf16, #spirv.storage_class<Input>>
 // NOEMU-SAME: memref<16xf16, #spirv.storage_class<Output>>
 func.func @memref_16bit_Output(%arg4: memref<16xf16, #spirv.storage_class<Output>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_StorageBuffer
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x i32, stride=4> [0])>, StorageBuffer>
-// NOEMU-LABEL: func @memref_64bit_StorageBuffer
+// We do not truncate i64 to i32.
+// CHECK-LABEL: func.func @memref_64bit_StorageBuffer
+// CHECK-SAME: memref<16xi64, #spirv.storage_class<StorageBuffer>>
+// NOEMU-LABEL: func.func @memref_64bit_StorageBuffer
 // NOEMU-SAME: memref<16xi64, #spirv.storage_class<StorageBuffer>>
 func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, #spirv.storage_class<StorageBuffer>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_Uniform
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x si32, stride=4> [0])>, Uniform>
-// NOEMU-LABEL: func @memref_64bit_Uniform
+// CHECK-LABEL: func.func @memref_64bit_Uniform
+// CHECK-SAME: memref<16xsi64, #spirv.storage_class<Uniform>>
+// NOEMU-LABEL: func.func @memref_64bit_Uniform
 // NOEMU-SAME: memref<16xsi64, #spirv.storage_class<Uniform>>
 func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, #spirv.storage_class<Uniform>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_PushConstant
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x ui32, stride=4> [0])>, PushConstant>
-// NOEMU-LABEL: func @memref_64bit_PushConstant
+// CHECK-LABEL: func.func @memref_64bit_PushConstant
+// CHECK-SAME: memref<16xui64, #spirv.storage_class<PushConstant>>
+// NOEMU-LABEL: func.func @memref_64bit_PushConstant
 // NOEMU-SAME: memref<16xui64, #spirv.storage_class<PushConstant>>
 func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, #spirv.storage_class<PushConstant>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_Input
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x f32>)>, Input>
-// NOEMU-LABEL: func @memref_64bit_Input
+// CHECK-LABEL: func.func @memref_64bit_Input
+// CHECK-SAME: memref<16xf64, #spirv.storage_class<Input>>
+// NOEMU-LABEL: func.func @memref_64bit_Input
 // NOEMU-SAME: memref<16xf64, #spirv.storage_class<Input>>
 func.func @memref_64bit_Input(%arg3: memref<16xf64, #spirv.storage_class<Input>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_Output
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x f32>)>, Output>
-// NOEMU-LABEL: func @memref_64bit_Output
+// CHECK-LABEL: func.func @memref_64bit_Output
+// CHECK-SAME: memref<16xf64, #spirv.storage_class<Output>>
+// NOEMU-LABEL: func.func @memref_64bit_Output
 // NOEMU-SAME: memref<16xf64, #spirv.storage_class<Output>>
 func.func @memref_64bit_Output(%arg4: memref<16xf64, #spirv.storage_class<Output>>) { return }
@@ -791,9 +791,7 @@ module attributes {
 // CHECK-SAME: !spirv.array<32 x i32>
 // CHECK-SAME: !spirv.array<32 x i32>
 // CHECK-SAME: !spirv.array<32 x i32>
-// CHECK-SAME: !spirv.array<32 x i32>
 func.func @int_tensor_types(
-  %arg0: tensor<8x4xi64>,
   %arg1: tensor<8x4xi32>,
   %arg2: tensor<8x4xi16>,
   %arg3: tensor<8x4xi8>
@@ -802,9 +800,7 @@ func.func @int_tensor_types(
 // CHECK-LABEL: spirv.func @float_tensor_types
 // CHECK-SAME: !spirv.array<32 x f32>
 // CHECK-SAME: !spirv.array<32 x f32>
-// CHECK-SAME: !spirv.array<32 x f32>
 func.func @float_tensor_types(
-  %arg0: tensor<8x4xf64>,
   %arg1: tensor<8x4xf32>,
   %arg2: tensor<8x4xf16>
 ) { return }


More information about the Mlir-commits mailing list