[Mlir-commits] [mlir] 5b15fe9 - [mlir][spirv] Only attach struct offset for required storage classes
Lei Zhang
llvmlistbot at llvm.org
Tue Apr 13 12:30:37 PDT 2021
Author: Lei Zhang
Date: 2021-04-13T15:30:30-04:00
New Revision: 5b15fe9334b802b928f2f6cfedde31bb8cba72ee
URL: https://github.com/llvm/llvm-project/commit/5b15fe9334b802b928f2f6cfedde31bb8cba72ee
DIFF: https://github.com/llvm/llvm-project/commit/5b15fe9334b802b928f2f6cfedde31bb8cba72ee.diff
LOG: [mlir][spirv] Only attach struct offset for required storage classes
Per the SPIR-V spec "2.16.2. Validation Rules for Shader Capabilities":
Composite objects in the StorageBuffer, PhysicalStorageBuffer,
Uniform, and PushConstant Storage Classes must be explicitly
laid out.
For other storage classes explicit layout is not required, so we no longer attach struct member offsets there.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D100386
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index ea703112a2da3..9de45e7200bd4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -84,6 +84,30 @@ static LogicalResult checkCapabilityRequirements(
return success();
}
+/// Returns true if composite objects in `storageClass` must be explicitly laid
+/// out (SPIR-V spec 2.16.2, Validation Rules for Shader Capabilities).
+static bool needsExplicitLayout(spirv::StorageClass storageClass) {
+ switch (storageClass) {
+ case spirv::StorageClass::PhysicalStorageBuffer:
+ case spirv::StorageClass::PushConstant:
+ case spirv::StorageClass::StorageBuffer:
+ case spirv::StorageClass::Uniform:
+ return true;
+ default: // e.g. Workgroup, CrossWorkgroup, Input, Output
+ return false;
+ }
+}
+
+/// Wraps `elementType` in a struct (for Vulkan interface requirements) and
+/// returns a pointer to it; offset 0 is attached only if needsExplicitLayout().
+static spirv::PointerType
+wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
+ auto structType = needsExplicitLayout(storageClass)
+ ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
+ : spirv::StructType::get(elementType);
+ return spirv::PointerType::get(structType, storageClass);
+}
+
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
@@ -392,12 +416,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayType =
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
- // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
- // workgroup storage class do not need the struct to be laid out explicitly.
- auto structType = *storageClass == spirv::StorageClass::Workgroup
- ? spirv::StructType::get(arrayType)
- : spirv::StructType::get(arrayType, 0);
- return spirv::PointerType::get(structType, *storageClass);
+ return wrapInStructAndGetPointer(arrayType, *storageClass);
}
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
@@ -452,9 +471,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
if (!type.hasStaticShape()) {
auto arrayType =
spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
- // Wrap in a struct to satisfy Vulkan interface requirements.
- auto structType = spirv::StructType::get(arrayType, 0);
- return spirv::PointerType::get(structType, *storageClass);
+ return wrapInStructAndGetPointer(arrayType, *storageClass);
}
Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
@@ -470,12 +487,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayType =
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
- // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
- // workgroup storage class do not need the struct to be laid out explicitly.
- auto structType = *storageClass == spirv::StorageClass::Workgroup
- ? spirv::StructType::get(arrayType)
- : spirv::StructType::get(arrayType, 0);
- return spirv::PointerType::get(structType, *storageClass);
+ return wrapInStructAndGetPointer(arrayType, *storageClass);
}
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
diff --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
index 3066e5bba34d2..96467fc698c32 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
// CHECK: spv.func
// CHECK-SAME: {{%.*}}: f32
// CHECK-NOT: spv.interface_var_abi
- // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4> [0])>, CrossWorkgroup>
+ // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4>)>, CrossWorkgroup>
// CHECK-NOT: spv.interface_var_abi
// CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 1574ba24383ea..aa4999d660a98 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -337,13 +337,13 @@ func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Input>
// NOEMU-LABEL: func @memref_16bit_Input
// NOEMU-SAME: memref<16xf16, 9>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Output>
// NOEMU-LABEL: func @memref_16bit_Output
// NOEMU-SAME: memref<16xf16, 10>
func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
@@ -451,15 +451,15 @@ module attributes {
} {
// CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
// NOEMU-LABEL: spv.func @memref_16bit_Input
-// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
// NOEMU-LABEL: spv.func @memref_16bit_Output
-// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
} // end module
@@ -563,13 +563,13 @@ func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Input>
// NOEMU-LABEL: func @memref_16bit_Input
// NOEMU-SAME: memref<?xf16, 9>
func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
// CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Output>
// NOEMU-LABEL: func @memref_16bit_Output
// NOEMU-SAME: memref<?xf16, 10>
func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }
More information about the Mlir-commits
mailing list