[Mlir-commits] [mlir] 48378a3 - [spirv] Fix bitwidth emulation for Workgroup storage class

Lei Zhang llvmlistbot at llvm.org
Wed Aug 5 11:44:11 PDT 2020


Author: Lei Zhang
Date: 2020-08-05T14:44:03-04:00
New Revision: 48378a32af54af6ae656a3db14dc7c0d975d0f48

URL: https://github.com/llvm/llvm-project/commit/48378a32af54af6ae656a3db14dc7c0d975d0f48
DIFF: https://github.com/llvm/llvm-project/commit/48378a32af54af6ae656a3db14dc7c0d975d0f48.diff

LOG: [spirv] Fix bitwidth emulation for Workgroup storage class

If Int16 is not available, 16-bit integers in the Workgroup storage
class should be emulated via 32-bit integers. This was previously
broken because the capability querying logic was incorrectly
intercepting all storage classes when it was meant to handle only
interface storage classes. This patch adjusts where it returns to fix this.

Differential Revision: https://reviews.llvm.org/D85308

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/test/Conversion/StandardToSPIRV/alloc.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 93d0c43d669f..583a779408b4 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -772,8 +772,12 @@ void ScalarType::getCapabilities(
       ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
       capabilities.push_back(ref);                                             \
     }                                                                          \
-  } break
+    /* No requirements for other bitwidths */                                  \
+    return;                                                                    \
+  }
 
+  // This part only handles the cases where special bitwidths appear in
+  // interface storage classes.
   if (storage) {
     switch (*storage) {
       STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
@@ -782,17 +786,17 @@ void ScalarType::getCapabilities(
       STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
                    StorageUniform16);
     case StorageClass::Input:
-    case StorageClass::Output:
+    case StorageClass::Output: {
       if (bitwidth == 16) {
         static const Capability caps[] = {Capability::StorageInputOutput16};
         ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
         capabilities.push_back(ref);
       }
-      break;
+      return;
+    }
     default:
       break;
     }
-    return;
   }
 #undef STORAGE_CASE
 

diff  --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
index fe4c9d125a26..14ce4699a455 100644
--- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -32,25 +32,34 @@ module attributes {
 
 // -----
 
-// TODO: Uncomment this test when the extension handling correctly
-// converts an i16 type to i32 type and handles the load/stores
-// correctly.
-
-// module attributes {
-//   spv.target_env = #spv.target_env<
-//     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
-//     {max_compute_workgroup_invocations = 128 : i32,
-//      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
-//   }
-// {
-//   func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
-//     %0 = alloc() : memref<4x5xi16, 3>
-//     %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
-//     store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
-//     dealloc %0 : memref<4x5xi16, 3>
-//     return
-//   }
-// }
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+    %0 = alloc() : memref<4x5xi16, 3>
+    %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
+    store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
+    dealloc %0 : memref<4x5xi16, 3>
+    return
+  }
+}
+
+//       CHECK: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+//  CHECK-SAME:   !spv.ptr<!spv.struct<!spv.array<20 x i32, stride=4>>, Workgroup>
+// CHECK_LABEL: spv.func @alloc_dealloc_workgroup_mem
+//       CHECK:   %[[VAR:.+]] = spv._address_of @__workgroup_mem__0
+//       CHECK:   %[[LOC:.+]] = spv.SDiv
+//       CHECK:   %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]]
+//       CHECK:   %{{.+}} = spv.Load "Workgroup" %[[PTR]] : i32
+//       CHECK:   %[[LOC:.+]] = spv.SDiv
+//       CHECK:   %[[PTR:.+]] = spv.AccessChain %[[VAR]][%{{.+}}, %[[LOC]]]
+//       CHECK:   %{{.+}} = spv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr<i32, Workgroup>
+//       CHECK:   %{{.+}} = spv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spv.ptr<i32, Workgroup>
+
 
 // -----
 


        


More information about the Mlir-commits mailing list