[Mlir-commits] [mlir] e0ea1fc - [mlir][spirv] Fix capability check for 64-bit element types
Lei Zhang
llvmlistbot at llvm.org
Wed May 25 07:57:58 PDT 2022
Author: Lei Zhang
Date: 2022-05-25T10:57:31-04:00
New Revision: e0ea1fc6f8aa5c51061dced0f86c4fd25e3e9333
URL: https://github.com/llvm/llvm-project/commit/e0ea1fc6f8aa5c51061dced0f86c4fd25e3e9333
DIFF: https://github.com/llvm/llvm-project/commit/e0ea1fc6f8aa5c51061dced0f86c4fd25e3e9333.diff
LOG: [mlir][spirv] Fix capability check for 64-bit element types
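Using 64-bit integer/float types in interface storage classes requires
the Int64/Float64 capability. Previously no capability was required for
them at all: getCapabilities() returned early for every bitwidth other
than 8 and 16. Per the Vulkan spec: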
```
shaderInt64 specifies whether 64-bit integers (signed and unsigned) are
supported in shader code. If this feature is not enabled, 64-bit integer
types must not be used in shader code. This also specifies whether
shader modules can declare the Int64 capability. Declaring and using
64-bit integers is enabled for all storage classes that SPIR-V allows
with the Int64 capability.
```
This is different from, say, 16-bit element types, where:
```
shaderInt16 specifies whether 16-bit integers (signed and unsigned) are
supported in shader code. If this feature is not enabled, 16-bit integer
types must not be used in shader code. This also specifies whether
shader modules can declare the Int16 capability. However, this only
enables a subset of the storage classes that SPIR-V allows for the Int16
SPIR-V capability: Declaring and using 16-bit integers in the Private,
Workgroup (for non-Block variables), and Function storage classes is
enabled, while declaring them in the interface storage classes (e.g.,
UniformConstant, Uniform, StorageBuffer, Input, Output, and
PushConstant) is not enabled.
```
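This commit also fixes the array element count computed in
SPIRVConversion.cpp when a memref's element type is emulated with a
different (32-bit) array element type. The count is now derived from the
array element size via llvm::divideCeil rather than from the original
element size. For example, memref<16xi8> in StorageBuffer now converts to
!spv.array<4 x i32> instead of !spv.array<16 x i32>.
Reviewed By: hanchung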
Differential Revision: https://reviews.llvm.org/D126256
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index b66f569c352d9..494f32925315b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -549,14 +549,17 @@ void ScalarType::getCapabilities(
static const Capability caps[] = {Capability::cap8}; \
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
capabilities.push_back(ref); \
- } else if (bitwidth == 16) { \
+ return; \
+ } \
+ if (bitwidth == 16) { \
static const Capability caps[] = {Capability::cap16}; \
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
capabilities.push_back(ref); \
+ return; \
} \
- /* No requirements for other bitwidths */ \
- return; \
- }
+ /* For 64-bit integers/floats, Int64/Float64 enables support for all */ \
+ /* storage classes. Fall through to the next section. */ \
+ } break
// This part only handles the cases where special bitwidths appearing in
// interface storage classes.
@@ -573,8 +576,9 @@ void ScalarType::getCapabilities(
static const Capability caps[] = {Capability::StorageInputOutput16};
ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
capabilities.push_back(ref);
+ return;
}
- return;
+ break;
}
default:
break;
@@ -594,22 +598,22 @@ void ScalarType::getCapabilities(
if (auto intType = dyn_cast<IntegerType>()) {
switch (bitwidth) {
- case 32:
- case 1:
- break;
WIDTH_CASE(Int, 8);
WIDTH_CASE(Int, 16);
WIDTH_CASE(Int, 64);
+ case 1:
+ case 32:
+ break;
default:
llvm_unreachable("invalid bitwidth to getCapabilities");
}
} else {
assert(isa<FloatType>());
switch (bitwidth) {
- case 32:
- break;
WIDTH_CASE(Float, 16);
WIDTH_CASE(Float, 64);
+ case 32:
+ break;
default:
llvm_unreachable("invalid bitwidth to getCapabilities");
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 68cb4ee225546..bad8922a6fb76 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -413,7 +413,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
}
int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
- auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
+ auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
@@ -455,13 +455,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
if (!arrayElemType)
return nullptr;
- Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
- if (!elementSize) {
- LLVM_DEBUG(llvm::dbgs()
- << type << " illegal: cannot deduce element size\n");
- return nullptr;
- }
-
Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
if (!arrayElemSize) {
LLVM_DEBUG(llvm::dbgs()
@@ -482,7 +475,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- auto arrayElemCount = *memrefSize / *elementSize;
+ auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index cce9e2749de3d..f7a213b091928 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -312,53 +312,83 @@ module attributes {
func.func @memref_1bit_type(%arg0: memref<5xi1>) { return }
// CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x i32, stride=4> [0])>, StorageBuffer>
// NOEMU-LABEL: func @memref_8bit_StorageBuffer
// NOEMU-SAME: memref<16xi8>
func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
// CHECK-LABEL: spv.func @memref_8bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x si32, stride=4> [0])>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x si32, stride=4> [0])>, Uniform>
// NOEMU-LABEL: func @memref_8bit_Uniform
// NOEMU-SAME: memref<16xsi8, 4>
func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }
// CHECK-LABEL: spv.func @memref_8bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x ui32, stride=4> [0])>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x ui32, stride=4> [0])>, PushConstant>
// NOEMU-LABEL: func @memref_8bit_PushConstant
// NOEMU-SAME: memref<16xui8, 7>
func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
// NOEMU-LABEL: func @memref_16bit_StorageBuffer
// NOEMU-SAME: memref<16xi16>
func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x si32, stride=4> [0])>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x si32, stride=4> [0])>, Uniform>
// NOEMU-LABEL: func @memref_16bit_Uniform
// NOEMU-SAME: memref<16xsi16, 4>
func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
// CHECK-LABEL: spv.func @memref_16bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x ui32, stride=4> [0])>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x ui32, stride=4> [0])>, PushConstant>
// NOEMU-LABEL: func @memref_16bit_PushConstant
// NOEMU-SAME: memref<16xui16, 7>
func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Input>
// NOEMU-LABEL: func @memref_16bit_Input
// NOEMU-SAME: memref<16xf16, 9>
func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Output>
// NOEMU-LABEL: func @memref_16bit_Output
// NOEMU-SAME: memref<16xf16, 10>
func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
+// CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x i32, stride=4> [0])>, StorageBuffer>
+// NOEMU-LABEL: func @memref_64bit_StorageBuffer
+// NOEMU-SAME: memref<16xi64>
+func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x si32, stride=4> [0])>, Uniform>
+// NOEMU-LABEL: func @memref_64bit_Uniform
+// NOEMU-SAME: memref<16xsi64, 4>
+func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x ui32, stride=4> [0])>, PushConstant>
+// NOEMU-LABEL: func @memref_64bit_PushConstant
+// NOEMU-SAME: memref<16xui64, 7>
+func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Input>
+// NOEMU-LABEL: func @memref_64bit_Input
+// NOEMU-SAME: memref<16xf64, 9>
+func.func @memref_64bit_Input(%arg3: memref<16xf64, 9>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Output>
+// NOEMU-LABEL: func @memref_64bit_Output
+// NOEMU-SAME: memref<16xf64, 10>
+func.func @memref_64bit_Output(%arg4: memref<16xf64, 10>) { return }
+
} // end module
// -----
@@ -368,7 +398,7 @@ func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
// and extension is available.
module attributes {
spv.target_env = #spv.target_env<
- #spv.vce<v1.0, [StoragePushConstant8, StoragePushConstant16],
+ #spv.vce<v1.0, [StoragePushConstant8, StoragePushConstant16, Int64, Float64],
[SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {}>
} {
@@ -389,6 +419,17 @@ func.func @memref_16bit_PushConstant(
%arg1: memref<16xf16, 7>
) { return }
+// CHECK-LABEL: spv.func @memref_64bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, PushConstant>
+// NOEMU-LABEL: spv.func @memref_64bit_PushConstant
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, PushConstant>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, PushConstant>
+func.func @memref_64bit_PushConstant(
+ %arg0: memref<16xi64, 7>,
+ %arg1: memref<16xf64, 7>
+) { return }
+
} // end module
// -----
@@ -398,7 +439,7 @@ func.func @memref_16bit_PushConstant(
// and extension is available.
module attributes {
spv.target_env = #spv.target_env<
- #spv.vce<v1.0, [StorageBuffer8BitAccess, StorageBuffer16BitAccess],
+ #spv.vce<v1.0, [StorageBuffer8BitAccess, StorageBuffer16BitAccess, Int64, Float64],
[SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {}>
} {
@@ -419,6 +460,17 @@ func.func @memref_16bit_StorageBuffer(
%arg1: memref<16xf16, 0>
) { return }
+// CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, StorageBuffer>
+// NOEMU-LABEL: spv.func @memref_64bit_StorageBuffer
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, StorageBuffer>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, StorageBuffer>
+func.func @memref_64bit_StorageBuffer(
+ %arg0: memref<16xi64, 0>,
+ %arg1: memref<16xf64, 0>
+) { return }
+
} // end module
// -----
@@ -428,7 +480,7 @@ func.func @memref_16bit_StorageBuffer(
// and extension is available.
module attributes {
spv.target_env = #spv.target_env<
- #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16],
+ #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Float64],
[SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {}>
} {
@@ -449,6 +501,17 @@ func.func @memref_16bit_Uniform(
%arg1: memref<16xf16, 4>
) { return }
+// CHECK-LABEL: spv.func @memref_64bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, Uniform>
+// NOEMU-LABEL: spv.func @memref_64bit_Uniform
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, Uniform>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, Uniform>
+func.func @memref_64bit_Uniform(
+ %arg0: memref<16xi64, 4>,
+ %arg1: memref<16xf64, 4>
+) { return }
+
} // end module
// -----
@@ -458,7 +521,7 @@ func.func @memref_16bit_Uniform(
// and extension is available.
module attributes {
spv.target_env = #spv.target_env<
- #spv.vce<v1.0, [StorageInputOutput16], [SPV_KHR_16bit_storage]>, {}>
+ #spv.vce<v1.0, [StorageInputOutput16, Int64, Float64], [SPV_KHR_16bit_storage]>, {}>
} {
// CHECK-LABEL: spv.func @memref_16bit_Input
@@ -473,6 +536,28 @@ func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
+// CHECK-LABEL: spv.func @memref_64bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Input>
+// NOEMU-LABEL: spv.func @memref_64bit_Input
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Input>
+func.func @memref_64bit_Input(
+ %arg0: memref<16xi64, 9>,
+ %arg1: memref<16xf64, 9>
+) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Output>
+// NOEMU-LABEL: spv.func @memref_64bit_Output
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Output>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Output>
+func.func @memref_64bit_Output(
+ %arg0: memref<16xi64, 10>,
+ %arg1: memref<16xf64, 10>
+) { return }
+
} // end module
// -----
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 498b0f9977647..692d70db83633 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -40,7 +40,7 @@ module attributes {
}
// CHECK: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<20 x i32>)>, Workgroup>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<10 x i32>)>, Workgroup>
// CHECK: func @alloc_dealloc_workgroup_mem
// CHECK: %[[VAR:.+]] = spv.mlir.addressof @__workgroup_mem__0
// CHECK: %[[LOC:.+]] = spv.SDiv
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index fa1f33a8fad76..7d09e1c507af5 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -186,10 +186,10 @@ spv.module Logical GLSL450 attributes {
// Complicated nested types
// * Buffer requires ImageBuffer or SampledBuffer.
// * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
spv.module Logical GLSL450 attributes {
spv.target_env = #spv.target_env<
- #spv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, ImageBuffer, StorageImageExtendedFormats], []>,
+ #spv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
{}>
} {
spv.GlobalVariable @data : !spv.ptr<!spv.struct<(i8 [0], f16 [2], i64 [4])>, Uniform>
More information about the Mlir-commits
mailing list