[Mlir-commits] [mlir] e0ea1fc - [mlir][spirv] Fix capability check for 64-bit element types

Lei Zhang llvmlistbot at llvm.org
Wed May 25 07:57:58 PDT 2022


Author: Lei Zhang
Date: 2022-05-25T10:57:31-04:00
New Revision: e0ea1fc6f8aa5c51061dced0f86c4fd25e3e9333

URL: https://github.com/llvm/llvm-project/commit/e0ea1fc6f8aa5c51061dced0f86c4fd25e3e9333
DIFF: https://github.com/llvm/llvm-project/commit/e0ea1fc6f8aa5c51061dced0f86c4fd25e3e9333.diff

LOG: [mlir][spirv] Fix capability check for 64-bit element types

Using 64-bit integer/float types in interface storage classes requires
the Int64/Float64 capability, per the Vulkan spec:

```
shaderInt64 specifies whether 64-bit integers (signed and unsigned) are
supported in shader code. If this feature is not enabled, 64-bit integer
types must not be used in shader code. This also specifies whether
shader modules can declare the Int64 capability. Declaring and using
64-bit integers is enabled for all storage classes that SPIR-V allows
with the Int64 capability.
```

This is different from, say, 16-bit element types, where:

```
shaderInt16 specifies whether 16-bit integers (signed and unsigned) are
supported in shader code. If this feature is not enabled, 16-bit integer
types must not be used in shader code. This also specifies whether
shader modules can declare the Int16 capability. However, this only
enables a subset of the storage classes that SPIR-V allows for the Int16
SPIR-V capability: Declaring and using 16-bit integers in the Private,
Workgroup (for non-Block variables), and Function storage classes is
enabled, while declaring them in the interface storage classes (e.g.,
UniformConstant, Uniform, StorageBuffer, Input, Output, and
PushConstant) is not enabled.
```

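This patch also fixes the element count of the SPIR-V array generated for
a memref whose element type is emulated with a wider type (e.g., i8 or
i16 emulated as i32). The count is now derived from the size of the
array element type rather than the original memref element type. As a
result, memref<16xi8> in StorageBuffer converts to !spv.array<4 x i32>
instead of the oversized !spv.array<16 x i32>.

Reviewed By: hanchung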

Differential Revision: https://reviews.llvm.org/D126256

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index b66f569c352d9..494f32925315b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -549,14 +549,17 @@ void ScalarType::getCapabilities(
       static const Capability caps[] = {Capability::cap8};                     \
       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
       capabilities.push_back(ref);                                             \
-    } else if (bitwidth == 16) {                                               \
+      return;                                                                  \
+    }                                                                          \
+    if (bitwidth == 16) {                                                      \
       static const Capability caps[] = {Capability::cap16};                    \
       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
       capabilities.push_back(ref);                                             \
+      return;                                                                  \
     }                                                                          \
-    /* No requirements for other bitwidths */                                  \
-    return;                                                                    \
-  }
+    /* For 64-bit integers/floats, Int64/Float64 enables support for all */    \
+    /* storage classes. Fall through to the next section. */                   \
+  } break
 
   // This part only handles the cases where special bitwidths appearing in
   // interface storage classes.
@@ -573,8 +576,9 @@ void ScalarType::getCapabilities(
         static const Capability caps[] = {Capability::StorageInputOutput16};
         ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
         capabilities.push_back(ref);
+        return;
       }
-      return;
+      break;
     }
     default:
       break;
@@ -594,22 +598,22 @@ void ScalarType::getCapabilities(
 
   if (auto intType = dyn_cast<IntegerType>()) {
     switch (bitwidth) {
-    case 32:
-    case 1:
-      break;
       WIDTH_CASE(Int, 8);
       WIDTH_CASE(Int, 16);
       WIDTH_CASE(Int, 64);
+    case 1:
+    case 32:
+      break;
     default:
       llvm_unreachable("invalid bitwidth to getCapabilities");
     }
   } else {
     assert(isa<FloatType>());
     switch (bitwidth) {
-    case 32:
-      break;
       WIDTH_CASE(Float, 16);
       WIDTH_CASE(Float, 64);
+    case 32:
+      break;
     default:
       llvm_unreachable("invalid bitwidth to getCapabilities");
     }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 68cb4ee225546..bad8922a6fb76 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -413,7 +413,7 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   }
 
   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
-  auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
+  auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
@@ -455,13 +455,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   if (!arrayElemType)
     return nullptr;
 
-  Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
-  if (!elementSize) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot deduce element size\n");
-    return nullptr;
-  }
-
   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
   if (!arrayElemSize) {
     LLVM_DEBUG(llvm::dbgs()
@@ -482,7 +475,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  auto arrayElemCount = *memrefSize / *elementSize;
+  auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index cce9e2749de3d..f7a213b091928 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -312,53 +312,83 @@ module attributes {
 func.func @memref_1bit_type(%arg0: memref<5xi1>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer
 // NOEMU-SAME: memref<16xi8>
 func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x si32, stride=4> [0])>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_8bit_Uniform
 // NOEMU-SAME: memref<16xsi8, 4>
 func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return }
 
 // CHECK-LABEL: spv.func @memref_8bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x ui32, stride=4> [0])>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<4 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_8bit_PushConstant
 // NOEMU-SAME: memref<16xui8, 7>
 func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_16bit_StorageBuffer
 // NOEMU-SAME: memref<16xi16>
 func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Uniform
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x si32, stride=4> [0])>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x si32, stride=4> [0])>, Uniform>
 // NOEMU-LABEL: func @memref_16bit_Uniform
 // NOEMU-SAME: memref<16xsi16, 4>
 func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_PushConstant
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x ui32, stride=4> [0])>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x ui32, stride=4> [0])>, PushConstant>
 // NOEMU-LABEL: func @memref_16bit_PushConstant
 // NOEMU-SAME: memref<16xui16, 7>
 func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Input>
 // NOEMU-LABEL: func @memref_16bit_Input
 // NOEMU-SAME: memref<16xf16, 9>
 func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
 
 // CHECK-LABEL: spv.func @memref_16bit_Output
-// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f32>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<8 x f32>)>, Output>
 // NOEMU-LABEL: func @memref_16bit_Output
 // NOEMU-SAME: memref<16xf16, 10>
 func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
 
+// CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x i32, stride=4> [0])>, StorageBuffer>
+// NOEMU-LABEL: func @memref_64bit_StorageBuffer
+// NOEMU-SAME: memref<16xi64>
+func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, 0>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x si32, stride=4> [0])>, Uniform>
+// NOEMU-LABEL: func @memref_64bit_Uniform
+// NOEMU-SAME: memref<16xsi64, 4>
+func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, 4>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x ui32, stride=4> [0])>, PushConstant>
+// NOEMU-LABEL: func @memref_64bit_PushConstant
+// NOEMU-SAME: memref<16xui64, 7>
+func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, 7>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Input>
+// NOEMU-LABEL: func @memref_64bit_Input
+// NOEMU-SAME: memref<16xf64, 9>
+func.func @memref_64bit_Input(%arg3: memref<16xf64, 9>) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<32 x f32>)>, Output>
+// NOEMU-LABEL: func @memref_64bit_Output
+// NOEMU-SAME: memref<16xf64, 10>
+func.func @memref_64bit_Output(%arg4: memref<16xf64, 10>) { return }
+
 } // end module
 
 // -----
@@ -368,7 +398,7 @@ func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return }
 // and extension is available.
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [StoragePushConstant8, StoragePushConstant16],
+    #spv.vce<v1.0, [StoragePushConstant8, StoragePushConstant16, Int64, Float64],
              [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {}>
 } {
 
@@ -389,6 +419,17 @@ func.func @memref_16bit_PushConstant(
   %arg1: memref<16xf16, 7>
 ) { return }
 
+// CHECK-LABEL: spv.func @memref_64bit_PushConstant
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, PushConstant>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, PushConstant>
+// NOEMU-LABEL: spv.func @memref_64bit_PushConstant
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, PushConstant>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, PushConstant>
+func.func @memref_64bit_PushConstant(
+  %arg0: memref<16xi64, 7>,
+  %arg1: memref<16xf64, 7>
+) { return }
+
 } // end module
 
 // -----
@@ -398,7 +439,7 @@ func.func @memref_16bit_PushConstant(
 // and extension is available.
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [StorageBuffer8BitAccess, StorageBuffer16BitAccess],
+    #spv.vce<v1.0, [StorageBuffer8BitAccess, StorageBuffer16BitAccess, Int64, Float64],
              [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {}>
 } {
 
@@ -419,6 +460,17 @@ func.func @memref_16bit_StorageBuffer(
   %arg1: memref<16xf16, 0>
 ) { return }
 
+// CHECK-LABEL: spv.func @memref_64bit_StorageBuffer
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, StorageBuffer>
+// NOEMU-LABEL: spv.func @memref_64bit_StorageBuffer
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, StorageBuffer>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, StorageBuffer>
+func.func @memref_64bit_StorageBuffer(
+  %arg0: memref<16xi64, 0>,
+  %arg1: memref<16xf64, 0>
+) { return }
+
 } // end module
 
 // -----
@@ -428,7 +480,7 @@ func.func @memref_16bit_StorageBuffer(
 // and extension is available.
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16],
+    #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Float64],
              [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>, {}>
 } {
 
@@ -449,6 +501,17 @@ func.func @memref_16bit_Uniform(
   %arg1: memref<16xf16, 4>
 ) { return }
 
+// CHECK-LABEL: spv.func @memref_64bit_Uniform
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, Uniform>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, Uniform>
+// NOEMU-LABEL: spv.func @memref_64bit_Uniform
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64, stride=8> [0])>, Uniform>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64, stride=8> [0])>, Uniform>
+func.func @memref_64bit_Uniform(
+  %arg0: memref<16xi64, 4>,
+  %arg1: memref<16xf64, 4>
+) { return }
+
 } // end module
 
 // -----
@@ -458,7 +521,7 @@ func.func @memref_16bit_Uniform(
 // and extension is available.
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [StorageInputOutput16], [SPV_KHR_16bit_storage]>, {}>
+    #spv.vce<v1.0, [StorageInputOutput16, Int64, Float64], [SPV_KHR_16bit_storage]>, {}>
 } {
 
 // CHECK-LABEL: spv.func @memref_16bit_Input
@@ -473,6 +536,28 @@ func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return }
 // NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i16>)>, Output>
 func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return }
 
+// CHECK-LABEL: spv.func @memref_64bit_Input
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Input>
+// NOEMU-LABEL: spv.func @memref_64bit_Input
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Input>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Input>
+func.func @memref_64bit_Input(
+  %arg0: memref<16xi64, 9>,
+  %arg1: memref<16xf64, 9>
+) { return }
+
+// CHECK-LABEL: spv.func @memref_64bit_Output
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Output>
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Output>
+// NOEMU-LABEL: spv.func @memref_64bit_Output
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i64>)>, Output>
+// NOEMU-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x f64>)>, Output>
+func.func @memref_64bit_Output(
+  %arg0: memref<16xi64, 10>,
+  %arg1: memref<16xf64, 10>
+) { return }
+
 } // end module
 
 // -----

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 498b0f9977647..692d70db83633 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -40,7 +40,7 @@ module attributes {
 }
 
 //       CHECK: spv.GlobalVariable @__workgroup_mem__{{[0-9]+}}
-//  CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<20 x i32>)>, Workgroup>
+//  CHECK-SAME:   !spv.ptr<!spv.struct<(!spv.array<10 x i32>)>, Workgroup>
 //       CHECK: func @alloc_dealloc_workgroup_mem
 //       CHECK:   %[[VAR:.+]] = spv.mlir.addressof @__workgroup_mem__0
 //       CHECK:   %[[LOC:.+]] = spv.SDiv

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index fa1f33a8fad76..7d09e1c507af5 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -186,10 +186,10 @@ spv.module Logical GLSL450 attributes {
 // Complicated nested types
 // * Buffer requires ImageBuffer or SampledBuffer.
 // * Rg32f requires StorageImageExtendedFormats.
-// CHECK: requires #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+// CHECK: requires #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Int64, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
 spv.module Logical GLSL450 attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, ImageBuffer, StorageImageExtendedFormats], []>,
+    #spv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, Int64, ImageBuffer, StorageImageExtendedFormats], []>,
     {}>
 } {
   spv.GlobalVariable @data : !spv.ptr<!spv.struct<(i8 [0], f16 [2], i64 [4])>, Uniform>


        


More information about the Mlir-commits mailing list