[Mlir-commits] [mlir] 5b15fe9 - [mlir][spirv] Only attach struct offset for required storage classes

Lei Zhang llvmlistbot at llvm.org
Tue Apr 13 12:30:37 PDT 2021


Author: Lei Zhang
Date: 2021-04-13T15:30:30-04:00
New Revision: 5b15fe9334b802b928f2f6cfedde31bb8cba72ee

URL: https://github.com/llvm/llvm-project/commit/5b15fe9334b802b928f2f6cfedde31bb8cba72ee
DIFF: https://github.com/llvm/llvm-project/commit/5b15fe9334b802b928f2f6cfedde31bb8cba72ee.diff

LOG: [mlir][spirv] Only attach struct offset for required storage classes

Per the SPIR-V spec "2.16.2. Validation Rules for Shader Capabilities":

  Composite objects in the StorageBuffer, PhysicalStorageBuffer,
  Uniform, and PushConstant Storage Classes must be explicitly
  laid out.

For all other storage classes (e.g. Workgroup, CrossWorkgroup, Input, Output), struct offsets need not be attached.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D100386

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index ea703112a2da3..9de45e7200bd4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -84,6 +84,30 @@ static LogicalResult checkCapabilityRequirements(
   return success();
 }
 
+/// Returns true if composite objects in `storageClass` must be explicitly laid
+/// out (SPIR-V spec 2.16.2) when used in Shader environments.
+static bool needsExplicitLayout(spirv::StorageClass storageClass) {
+  switch (storageClass) {
+  case spirv::StorageClass::PhysicalStorageBuffer:
+  case spirv::StorageClass::PushConstant:
+  case spirv::StorageClass::StorageBuffer:
+  case spirv::StorageClass::Uniform:
+    return true;
+  default: // e.g. Workgroup, CrossWorkgroup, Input, Output.
+    return false;
+  }
+}
+
+/// Wraps `elementType` in a struct, as Vulkan interfaces require, and returns
+/// a pointer to it. Member offset 0 is attached iff needsExplicitLayout().
+static spirv::PointerType
+wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
+  auto structType = needsExplicitLayout(storageClass)
+                        ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
+                        : spirv::StructType::get(elementType);
+  return spirv::PointerType::get(structType, storageClass);
+}
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//
@@ -392,12 +416,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   auto arrayType =
       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
 
-  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
-  // workgroup storage class do not need the struct to be laid out explicitly.
-  auto structType = *storageClass == spirv::StorageClass::Workgroup
-                        ? spirv::StructType::get(arrayType)
-                        : spirv::StructType::get(arrayType, 0);
-  return spirv::PointerType::get(structType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, *storageClass);
 }
 
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
@@ -452,9 +471,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   if (!type.hasStaticShape()) {
     auto arrayType =
         spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
-    // Wrap in a struct to satisfy Vulkan interface requirements.
-    auto structType = spirv::StructType::get(arrayType, 0);
-    return spirv::PointerType::get(structType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, *storageClass);
   }
 
   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
@@ -470,12 +487,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   auto arrayType =
       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
 
-  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
-  // workgroup storage class do not need the struct to be laid out explicitly.
-  auto structType = *storageClass == spirv::StorageClass::Workgroup
-                        ? spirv::StructType::get(arrayType)
-                        : spirv::StructType::get(arrayType, 0);
-  return spirv::PointerType::get(structType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, *storageClass);
 }
 
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,

diff  --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
index 3066e5bba34d2..96467fc698c32 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
     //       CHECK:   spv.func
     //  CHECK-SAME:     {{%.*}}: f32
     //   CHECK-NOT:     spv.interface_var_abi
-    //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4> [0])>, CrossWorkgroup>
+    //  CHECK-SAME:     {{%.*}}: !spv.ptr<!spv.struct<(!spv.array<12 x f32, stride=4>)>, CrossWorkgroup>
     //   CHECK-NOT:     spv.interface_var_abi
     //  CHECK-SAME:     spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 1574ba24383ea..aa4999d660a98 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -337,13 +337,13 @@ func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
 func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
 // NOEMU-SAME: memref<16xf16, 9>
 func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4> [0])>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32, stride=4>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
 // NOEMU-SAME: memref<16xf16, 10>
 func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
@@ -451,15 +451,15 @@ module attributes {
 } {
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
 // NOEMU-LABEL: spv.func @memref_16bit_Input
-// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2> [0])>, Input>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f16, stride=2>)>, Input>
 func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
 // NOEMU-LABEL: spv.func @memref_16bit_Output
-// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2> [0])>, Output>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16, stride=2>)>, Output>
 func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
 
 } // end module
@@ -563,13 +563,13 @@ func @memref_16bit_Uniform(%arg0: memref<?xsi16, 4>) { return }
 func @memref_16bit_PushConstant(%arg0: memref<?xui16, 7>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
 // NOEMU-SAME: memref<?xf16, 9>
 func @memref_16bit_Input(%arg3: memref<?xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
 // NOEMU-SAME: memref<?xf16, 10>
 func @memref_16bit_Output(%arg4: memref<?xf16, 10>) { return }


        


More information about the Mlir-commits mailing list