[Mlir-commits] [mlir] c064545 - [mlir][spirv] Do not truncate i/f64 -> i/f32 in SPIRVConversion
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 4 12:11:43 PDT 2022
Author: Jakub Kuderski
Date: 2022-11-04T15:10:28-04:00
New Revision: c064545403917bedd450e07209e7870f1773f90f
URL: https://github.com/llvm/llvm-project/commit/c064545403917bedd450e07209e7870f1773f90f
DIFF: https://github.com/llvm/llvm-project/commit/c064545403917bedd450e07209e7870f1773f90f.diff
LOG: [mlir][spirv] Do not truncate i/f64 -> i/f32 in SPIRVConversion
This truncation can be unexpected and break program behavior.
Dedicated emulation passes should be used instead.
Also rename pass options to "emulate-lt-32-bit-scalar-types".
Fixes: https://github.com/llvm/llvm-project/issues/57917
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D137115
Added:
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 66ac9eedf1bfb..cef82f1e29ff1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -118,10 +118,10 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
let constructor = "mlir::arith::createConvertArithToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
let options = [
- Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+ Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
- "Emulate non-32-bit scalar types with 32-bit ones if "
- "missing native support">,
+ "Emulate narrower scalar types with 32-bit ones if not supported by "
+ "the target">,
Option<"enableFastMath", "enable-fast-math",
"bool", /*default=*/"false",
"Enable fast math mode (assuming no NaN and infinity for floating "
@@ -259,10 +259,10 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
let constructor = "mlir::createConvertControlFlowToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
let options = [
- Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+ Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
- "Emulate non-32-bit scalar types with 32-bit ones if "
- "missing native support">
+ "Emulate narrower scalar types with 32-bit ones if not supported by"
+ " the target">
];
}
@@ -320,10 +320,10 @@ def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
let constructor = "mlir::createConvertFuncToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
let options = [
- Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+ Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
- "Emulate non-32-bit scalar types with 32-bit ones if "
- "missing native support">
+ "Emulate narrower scalar types with 32-bit ones if not supported by"
+ " the target">
];
}
@@ -815,10 +815,10 @@ def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv"> {
let constructor = "mlir::createConvertTensorToSPIRVPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
let options = [
- Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
+ Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
- "Emulate non-32-bit scalar types with 32-bit ones if "
- "missing native support">
+ "Emulate narrower scalar types with 32-bit ones if not supported by"
+ " the target">
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 9b480f6cc9e3a..7d362526cc22f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -30,21 +30,21 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};
- /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
- /// no native support.
+ /// Whether to emulate narrower scalar types with 32-bit scalar types if not
+ /// supported by the target.
///
/// Non-32-bit scalar types require special hardware support that may not
/// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
/// types require special capabilities or extensions. This option controls
- /// whether to use 32-bit types to emulate, if a scalar type of a certain
- /// bitwidth is not supported in the target environment. This requires the
- /// runtime to also feed in data with a matched bitwidth and layout for
- /// interface types. The runtime can do that by inspecting the SPIR-V
- /// module.
+ /// whether to use 32-bit types to emulate < 32-bits-wide scalars, if a scalar
+ /// type of a certain bitwidth is not supported in the target environment.
+ /// This requires the runtime to also feed in data with a matched bitwidth and
+ /// layout for interface types. The runtime can do that by inspecting the
+ /// SPIR-V module.
///
/// If the original scalar type has less than 32-bit, a multiple of its
/// values will be packed into one 32-bit value to be memory efficient.
- bool emulateNon32BitScalarTypes{true};
+ bool emulateLT32BitScalarTypes{true};
/// Use 64-bit integers to convert index types.
bool use64bitIndex{false};
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 2452928dd4503..cf65beb924fb7 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -1031,7 +1031,7 @@ struct ConvertArithToSPIRVPass
auto target = SPIRVConversionTarget::get(targetAttr);
SPIRVConversionOptions options;
- options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+ options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.enableFastMathMode = this->enableFastMath;
SPIRVTypeConverter typeConverter(targetAttr, options);
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 0d1e8b8079465..d8aecae257b46 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -41,7 +41,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionTarget::get(targetAttr);
SPIRVConversionOptions options;
- options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+ options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index a82ba5dd12a5d..9fffc5e3182e9 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -40,7 +40,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionTarget::get(targetAttr);
SPIRVConversionOptions options;
- options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+ options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index 6b1145c464787..313172614268d 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -38,7 +38,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionTarget::get(targetAttr);
SPIRVConversionOptions options;
- options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
+ options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2514cfe0301a1..286ff0b7eff2d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -220,9 +220,16 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
// Otherwise we need to adjust the type, which really means adjusting the
// bitwidth given this is a scalar type.
+ if (!options.emulateLT32BitScalarTypes)
+ return nullptr;
- if (!options.emulateNon32BitScalarTypes)
+ // We only emulate narrower scalar types here and do not truncate results.
+ if (type.getIntOrFloatBitWidth() > 32) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type
+ << " not converted to 32-bit for SPIR-V to avoid truncation\n");
return nullptr;
+ }
if (auto floatType = type.dyn_cast<FloatType>()) {
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 967adbc84a3bb..f6e84e80bbf51 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -49,7 +49,15 @@ func.func @int_vector4_invalid(%arg0: vector<2xi16>) {
// -----
-func.func @unsupported_constant_0() {
+func.func @unsupported_constant_i64_0() {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+ %0 = arith.constant 0 : i64
+ return
+}
+
+// -----
+
+func.func @unsupported_constant_i64_1() {
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
%0 = arith.constant 4294967296 : i64 // 2^32
return
@@ -57,16 +65,68 @@ func.func @unsupported_constant_0() {
// -----
-func.func @unsupported_constant_1() {
+func.func @unsupported_constant_vector_2xi64_0() {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+ %1 = arith.constant dense<0> : vector<2xi64>
+ return
+}
+
+// -----
+
+func.func @unsupported_constant_f64_0() {
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
- %1 = arith.constant -2147483649 : i64 // -2^31 - 1
+ %1 = arith.constant 0.0 : f64
return
}
// -----
-func.func @unsupported_constant_2() {
+func.func @unsupported_constant_vector_2xf64_0() {
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
- %2 = arith.constant -2147483649 : i64 // -2^31 - 1
+ %1 = arith.constant dense<0.0> : vector<2xf64>
return
}
+
+// -----
+
+func.func @unsupported_constant_tensor_2xf64_0() {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant'}}
+ %1 = arith.constant dense<0.0> : tensor<2xf64>
+ return
+}
+
+///===----------------------------------------------------------------------===//
+// Type emulation
+//===----------------------------------------------------------------------===//
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Check that we do not emulate i64 by truncating to i32.
+func.func @unsupported_i64(%arg0: i64) {
+  // expected-error @+1 {{failed to legalize operation 'arith.addi'}}
+ %2 = arith.addi %arg0, %arg0: i64
+ return
+}
+
+} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Check that we do not emulate f64 by truncating to i32.
+func.func @unsupported_f64(%arg0: f64) {
+  // expected-error @+1 {{failed to legalize operation 'arith.addf'}}
+ %2 = arith.addf %arg0, %arg0: f64
+ return
+}
+
+} // end module
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index df6806a0e4bd1..d561cd2c26f29 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -513,7 +513,7 @@ func.func @constant_size1() {
// -----
-// Check that constants are converted to 32-bit when no special capability.
+// Check that constants are widened to 32-bit when no special capability.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {
@@ -533,51 +533,26 @@ func.func @constant_16bit() {
return
}
-// CHECK-LABEL: @constant_64bit
-func.func @constant_64bit() {
- // CHECK: spirv.Constant 4 : i32
- %0 = arith.constant 4 : i64
- // CHECK: spirv.Constant 5.000000e+00 : f32
- %1 = arith.constant 5.0 : f64
- // CHECK: spirv.Constant dense<[2, 3]> : vector<2xi32>
- %2 = arith.constant dense<[2, 3]> : vector<2xi64>
- // CHECK: spirv.Constant dense<4.000000e+00> : tensor<5xf32> : !spirv.array<5 x f32>
- %3 = arith.constant dense<4.0> : tensor<5xf64>
- // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spirv.array<4 x f32>
- %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
- return
-}
-
// CHECK-LABEL: @constant_size1
func.func @constant_size1() {
// CHECK: spirv.Constant 4 : i32
- %0 = arith.constant dense<4> : vector<1xi64>
+ %0 = arith.constant dense<4> : vector<1xi16>
// CHECK: spirv.Constant 5.000000e+00 : f32
- %1 = arith.constant dense<5.0> : tensor<1xf64>
+ %1 = arith.constant dense<5.0> : tensor<1xf16>
return
}
// CHECK-LABEL: @corner_cases
func.func @corner_cases() {
- // CHECK: %{{.*}} = spirv.Constant -1 : i32
- %0 = arith.constant 4294967295 : i64 // 2^32 - 1
- // CHECK: %{{.*}} = spirv.Constant 2147483647 : i32
- %1 = arith.constant 2147483647 : i64 // 2^31 - 1
- // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
- %2 = arith.constant 2147483648 : i64 // 2^31
- // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
- %3 = arith.constant -2147483648 : i64 // -2^31
-
- // CHECK: %{{.*}} = spirv.Constant -1 : i32
- %5 = arith.constant -1 : i64
+ // CHECK: %{{.*}} = spirv.Constant -1 : i32
+ %5 = arith.constant -1 : i16
// CHECK: %{{.*}} = spirv.Constant -2 : i32
- %6 = arith.constant -2 : i64
+ %6 = arith.constant -2 : i16
// CHECK: %{{.*}} = spirv.Constant -1 : i32
%7 = arith.constant -1 : index
// CHECK: %{{.*}} = spirv.Constant -2 : i32
%8 = arith.constant -2 : index
-
// CHECK: spirv.Constant false
%9 = arith.constant false
// CHECK: spirv.Constant true
@@ -903,29 +878,13 @@ module attributes {
} {
// CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[A:.*]]: f64
-func.func @fptrunc1(%arg0 : f64) -> f16 {
- // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
- // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
- %0 = arith.truncf %arg0 : f64 to f16
- return %0: f16
-}
-
-// CHECK-LABEL: @fptrunc2
// CHECK-SAME: %[[ARG:.*]]: f32
-func.func @fptrunc2(%arg0: f32) -> f16 {
+func.func @fptrunc1(%arg0: f32) -> f16 {
// CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
%0 = arith.truncf %arg0 : f32 to f16
return %0: f16
}
-// CHECK-LABEL: @sitofp
-func.func @sitofp(%arg0 : i64) -> f64 {
- // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
- %0 = arith.sitofp %arg0 : i64 to f64
- return %0: f64
-}
-
} // end module
// -----
@@ -1209,11 +1168,9 @@ func.func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
}
// CHECK-LABEL: @float_scalar
-func.func @float_scalar(%arg0: f16, %arg1: f64) {
+func.func @float_scalar(%arg0: f16) {
// CHECK: spirv.FAdd %{{.*}}, %{{.*}}: f32
%0 = arith.addf %arg0, %arg0: f16
- // CHECK: spirv.FMul %{{.*}}, %{{.*}}: f32
- %1 = arith.mulf %arg1, %arg1: f64
return
}
@@ -1513,74 +1470,6 @@ func.func @constant_64bit() {
// -----
-// Check that constants are converted to 32-bit when no special capability.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-LABEL: @constant_16bit
-func.func @constant_16bit() {
- // CHECK: spirv.Constant 4 : i32
- %0 = arith.constant 4 : i16
- // CHECK: spirv.Constant 5.000000e+00 : f32
- %1 = arith.constant 5.0 : f16
- // CHECK: spirv.Constant dense<[2, 3]> : vector<2xi32>
- %2 = arith.constant dense<[2, 3]> : vector<2xi16>
- // CHECK: spirv.Constant dense<4.000000e+00> : tensor<5xf32> : !spirv.array<5 x f32>
- %3 = arith.constant dense<4.0> : tensor<5xf16>
- // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spirv.array<4 x f32>
- %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
- return
-}
-
-// CHECK-LABEL: @constant_64bit
-func.func @constant_64bit() {
- // CHECK: spirv.Constant 4 : i32
- %0 = arith.constant 4 : i64
- // CHECK: spirv.Constant 5.000000e+00 : f32
- %1 = arith.constant 5.0 : f64
- // CHECK: spirv.Constant dense<[2, 3]> : vector<2xi32>
- %2 = arith.constant dense<[2, 3]> : vector<2xi64>
- // CHECK: spirv.Constant dense<4.000000e+00> : tensor<5xf32> : !spirv.array<5 x f32>
- %3 = arith.constant dense<4.0> : tensor<5xf64>
- // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> : !spirv.array<4 x f32>
- %4 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf16>
- return
-}
-
-// CHECK-LABEL: @corner_cases
-func.func @corner_cases() {
- // CHECK: %{{.*}} = spirv.Constant -1 : i32
- %0 = arith.constant 4294967295 : i64 // 2^32 - 1
- // CHECK: %{{.*}} = spirv.Constant 2147483647 : i32
- %1 = arith.constant 2147483647 : i64 // 2^31 - 1
- // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
- %2 = arith.constant 2147483648 : i64 // 2^31
- // CHECK: %{{.*}} = spirv.Constant -2147483648 : i32
- %3 = arith.constant -2147483648 : i64 // -2^31
-
- // CHECK: %{{.*}} = spirv.Constant -1 : i32
- %5 = arith.constant -1 : i64
- // CHECK: %{{.*}} = spirv.Constant -2 : i32
- %6 = arith.constant -2 : i64
- // CHECK: %{{.*}} = spirv.Constant -1 : i32
- %7 = arith.constant -1 : index
- // CHECK: %{{.*}} = spirv.Constant -2 : i32
- %8 = arith.constant -2 : index
-
-
- // CHECK: spirv.Constant false
- %9 = arith.constant false
- // CHECK: spirv.Constant true
- %10 = arith.constant true
-
- return
-}
-
-} // end module
-
-// -----
-
//===----------------------------------------------------------------------===//
// std cast ops
//===----------------------------------------------------------------------===//
@@ -1847,39 +1736,3 @@ func.func @fpext2(%arg0 : f32) -> f64 {
}
} // end module
-
-// -----
-
-// Checks that cast types will be adjusted when missing special capabilities for
-// certain non-32-bit scalar types.
-module attributes {
- spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float16], []>, #spirv.resource_limits<>>
-} {
-
-// CHECK-LABEL: @fptrunc1
-// CHECK-SAME: %[[A:.*]]: f64
-func.func @fptrunc1(%arg0 : f64) -> f16 {
- // CHECK: %[[ARG:.+]] = builtin.unrealized_conversion_cast %[[A]] : f64 to f32
- // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
- %0 = arith.truncf %arg0 : f64 to f16
- return %0: f16
-}
-
-// CHECK-LABEL: @fptrunc2
-// CHECK-SAME: %[[ARG:.*]]: f32
-func.func @fptrunc2(%arg0: f32) -> f16 {
- // CHECK-NEXT: spirv.FConvert %[[ARG]] : f32 to f16
- %0 = arith.truncf %arg0 : f32 to f16
- return %0: f16
-}
-
-// CHECK-LABEL: @sitofp
-func.func @sitofp(%arg0 : i64) -> f64 {
- // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
- %0 = arith.sitofp %arg0 : i64 to f64
- return %0: f64
-}
-
-} // end module
-
-// -----
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 4f1cd09efec30..d207ecd71c3cb 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
-// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-non-32-bit-scalar-types=false" %s -o - | FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
+// RUN: FileCheck %s --check-prefix=NOEMU
//===----------------------------------------------------------------------===//
// Integer types
@@ -15,7 +16,7 @@ module attributes {
// CHECK-SAME: i32
// CHECK-SAME: si32
// CHECK-SAME: ui32
-// NOEMU-LABEL: func @integer8
+// NOEMU-LABEL: func.func @integer8
// NOEMU-SAME: i8
// NOEMU-SAME: si8
// NOEMU-SAME: ui8
@@ -25,16 +26,17 @@ func.func @integer8(%arg0: i8, %arg1: si8, %arg2: ui8) { return }
// CHECK-SAME: i32
// CHECK-SAME: si32
// CHECK-SAME: ui32
-// NOEMU-LABEL: func @integer16
+// NOEMU-LABEL: func.func @integer16
// NOEMU-SAME: i16
// NOEMU-SAME: si16
// NOEMU-SAME: ui16
func.func @integer16(%arg0: i16, %arg1: si16, %arg2: ui16) { return }
-// CHECK-LABEL: spirv.func @integer64
-// CHECK-SAME: i32
-// CHECK-SAME: si32
-// CHECK-SAME: ui32
+// We do not truncate 64-bit types to 32-bit ones.
+// CHECK-LABEL: func.func @integer64
+// CHECK-SAME: i64
+// CHECK-SAME: si64
+// CHECK-SAME: ui64
// NOEMU-LABEL: func @integer64
// NOEMU-SAME: i64
// NOEMU-SAME: si64
@@ -131,13 +133,13 @@ module attributes {
// CHECK-LABEL: spirv.func @float16
// CHECK-SAME: f32
-// NOEMU-LABEL: func @float16
+// NOEMU-LABEL: func.func @float16
// NOEMU-SAME: f16
func.func @float16(%arg0: f16) { return }
-// CHECK-LABEL: spirv.func @float64
-// CHECK-SAME: f32
-// NOEMU-LABEL: func @float64
+// CHECK-LABEL: func.func @float64
+// CHECK-SAME: f64
+// NOEMU-LABEL: func.func @float64
// NOEMU-SAME: f64
func.func @float64(%arg0: f64) { return }
@@ -184,7 +186,7 @@ func.func @bf16_type(%arg0: bf16) { return }
//===----------------------------------------------------------------------===//
// Check that capabilities for scalar types affects vector types too: no special
-// capabilities available means using turning element types to 32-bit.
+// capabilities available means widening element types to 32-bit.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {
@@ -192,19 +194,15 @@ module attributes {
// CHECK-LABEL: spirv.func @int_vector
// CHECK-SAME: vector<2xi32>
// CHECK-SAME: vector<3xsi32>
-// CHECK-SAME: vector<4xui32>
func.func @int_vector(
%arg0: vector<2xi8>,
- %arg1: vector<3xsi16>,
- %arg2: vector<4xui64>
+ %arg1: vector<3xsi16>
) { return }
// CHECK-LABEL: spirv.func @float_vector
// CHECK-SAME: vector<2xf32>
-// CHECK-SAME: vector<3xf32>
func.func @float_vector(
- %arg0: vector<2xf16>,
- %arg1: vector<3xf64>
+ %arg0: vector<2xf16>
) { return }
// CHECK-LABEL: spirv.func @one_element_vector
@@ -389,33 +387,35 @@ func.func @memref_16bit_Input(%arg3: memref<16xf16, #spirv.storage_class<Input>>
// NOEMU-SAME: memref<16xf16, #spirv.storage_class<Output>>
func.func @memref_16bit_Output(%arg4: memref<16xf16, #spirv.storage_class<Output>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_StorageBuffer
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x i32, stride=4> [0])>, StorageBuffer>
-// NOEMU-LABEL: func @memref_64bit_StorageBuffer
+// We do not truncate i64 to i32.
+
+// CHECK-LABEL: func.func @memref_64bit_StorageBuffer
+// CHECK-SAME: memref<16xi64, #spirv.storage_class<StorageBuffer>>
+// NOEMU-LABEL: func.func @memref_64bit_StorageBuffer
// NOEMU-SAME: memref<16xi64, #spirv.storage_class<StorageBuffer>>
func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, #spirv.storage_class<StorageBuffer>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_Uniform
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x si32, stride=4> [0])>, Uniform>
-// NOEMU-LABEL: func @memref_64bit_Uniform
+// CHECK-LABEL: func.func @memref_64bit_Uniform
+// CHECK-SAME: memref<16xsi64, #spirv.storage_class<Uniform>>
+// NOEMU-LABEL: func.func @memref_64bit_Uniform
// NOEMU-SAME: memref<16xsi64, #spirv.storage_class<Uniform>>
func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, #spirv.storage_class<Uniform>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_PushConstant
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x ui32, stride=4> [0])>, PushConstant>
-// NOEMU-LABEL: func @memref_64bit_PushConstant
+// CHECK-LABEL: func.func @memref_64bit_PushConstant
+// CHECK-SAME: memref<16xui64, #spirv.storage_class<PushConstant>>
+// NOEMU-LABEL: func.func @memref_64bit_PushConstant
// NOEMU-SAME: memref<16xui64, #spirv.storage_class<PushConstant>>
func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, #spirv.storage_class<PushConstant>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_Input
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x f32>)>, Input>
-// NOEMU-LABEL: func @memref_64bit_Input
+// CHECK-LABEL: func.func @memref_64bit_Input
+// CHECK-SAME: memref<16xf64, #spirv.storage_class<Input>>
+// NOEMU-LABEL: func.func @memref_64bit_Input
// NOEMU-SAME: memref<16xf64, #spirv.storage_class<Input>>
func.func @memref_64bit_Input(%arg3: memref<16xf64, #spirv.storage_class<Input>>) { return }
-// CHECK-LABEL: spirv.func @memref_64bit_Output
-// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<32 x f32>)>, Output>
-// NOEMU-LABEL: func @memref_64bit_Output
+// CHECK-LABEL: func.func @memref_64bit_Output
+// CHECK-SAME: memref<16xf64, #spirv.storage_class<Output>>
+// NOEMU-LABEL: func.func @memref_64bit_Output
// NOEMU-SAME: memref<16xf64, #spirv.storage_class<Output>>
func.func @memref_64bit_Output(%arg4: memref<16xf64, #spirv.storage_class<Output>>) { return }
@@ -791,9 +791,7 @@ module attributes {
// CHECK-SAME: !spirv.array<32 x i32>
// CHECK-SAME: !spirv.array<32 x i32>
// CHECK-SAME: !spirv.array<32 x i32>
-// CHECK-SAME: !spirv.array<32 x i32>
func.func @int_tensor_types(
- %arg0: tensor<8x4xi64>,
%arg1: tensor<8x4xi32>,
%arg2: tensor<8x4xi16>,
%arg3: tensor<8x4xi8>
@@ -802,9 +800,7 @@ func.func @int_tensor_types(
// CHECK-LABEL: spirv.func @float_tensor_types
// CHECK-SAME: !spirv.array<32 x f32>
// CHECK-SAME: !spirv.array<32 x f32>
-// CHECK-SAME: !spirv.array<32 x f32>
func.func @float_tensor_types(
- %arg0: tensor<8x4xf64>,
%arg1: tensor<8x4xf32>,
%arg2: tensor<8x4xf16>
) { return }
More information about the Mlir-commits
mailing list