[Mlir-commits] [mlir] 59156ba - [mlir][spirv] Add support for converting memref of vector to SPIR-V

Thomas Raoux llvmlistbot at llvm.org
Thu Jul 30 15:06:07 PDT 2020


Author: Thomas Raoux
Date: 2020-07-30T15:05:40-07:00
New Revision: 59156bad03ffe37558b95fca62b5df4394de280c

URL: https://github.com/llvm/llvm-project/commit/59156bad03ffe37558b95fca62b5df4394de280c
DIFF: https://github.com/llvm/llvm-project/commit/59156bad03ffe37558b95fca62b5df4394de280c.diff

LOG: [mlir][spirv] Add support for converting memref of vector to SPIR-V

This allows declaring buffers and allocs of vectors so that we can support
vector load/store.

Differential Revision: https://reviews.llvm.org/D84982

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/test/Conversion/StandardToSPIRV/alloc.mlir
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index e59830fcef89..543b23acabeb 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -217,11 +217,15 @@ CHECK_UNSIGNED_OP(spirv::UModOp)
 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
 static bool isAllocationSupported(MemRefType t) {
   // Currently only support workgroup local memory allocations with static
-  // shape and int or float element type.
-  return t.hasStaticShape() &&
-         SPIRVTypeConverter::getMemorySpaceForStorageClass(
-             spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
-         t.getElementType().isIntOrFloat();
+  // shape and int or float or vector of int or float element type.
+  if (!(t.hasStaticShape() &&
+        SPIRVTypeConverter::getMemorySpaceForStorageClass(
+            spirv::StorageClass::Workgroup) == t.getMemorySpace()))
+    return false;
+  Type elementType = t.getElementType();
+  if (auto vecType = elementType.dyn_cast<VectorType>())
+    elementType = vecType.getElementType();
+  return elementType.isIntOrFloat();
 }
 
 /// Returns the scope to use for atomic operations use for emulating store

diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index d31f9c28362a..3d7535f9110e 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -170,7 +170,14 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
       return llvm::None;
     }
     return bitWidth / 8;
-  } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
+  }
+  if (auto vecType = t.dyn_cast<VectorType>()) {
+    auto elementSize = getTypeNumBytes(vecType.getElementType());
+    if (!elementSize)
+      return llvm::None;
+    return vecType.getNumElements() * *elementSize;
+  }
+  if (auto memRefType = t.dyn_cast<MemRefType>()) {
     // TODO: Layout should also be controlled by the ABI attributes. For now
     // using the layout from MemRef.
     int64_t offset;
@@ -343,26 +350,31 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
     return llvm::None;
   }
 
-  auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
-  if (!scalarType) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert non-scalar element type\n");
+  Optional<Type> arrayElemType;
+  Type elementType = type.getElementType();
+  if (auto vecType = elementType.dyn_cast<VectorType>()) {
+    arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
+  } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+    arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
+  } else {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << type
+        << " unhandled: can only convert scalar or vector element type\n");
     return llvm::None;
   }
-
-  auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
   if (!arrayElemType)
     return llvm::None;
 
-  Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
-  if (!scalarSize) {
+  Optional<int64_t> elementSize = getTypeNumBytes(elementType);
+  if (!elementSize) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: cannot deduce element size\n");
     return llvm::None;
   }
 
   if (!type.hasStaticShape()) {
-    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
+    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
     // Wrap in a struct to satisfy Vulkan interface requirements.
     auto structType = spirv::StructType::get(arrayType, 0);
     return spirv::PointerType::get(structType, *storageClass);
@@ -375,7 +387,7 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
     return llvm::None;
   }
 
-  auto arrayElemCount = *memrefSize / *scalarSize;
+  auto arrayElemCount = *memrefSize / *elementSize;
 
   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
   if (!arrayElemSize) {

diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
index 3cbeda1cafb0..fe4c9d125a26 100644
--- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -75,6 +75,30 @@ module attributes {
 //      CHECK: spv.func @two_allocs()
 //      CHECK: spv.Return
 
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @two_allocs_vector() {
+    %0 = alloc() : memref<4xvector<4xf32>, 3>
+    %1 = alloc() : memref<2xvector<2xi32>, 3>
+    return
+  }
+}
+
+//  CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME:   !spv.ptr<!spv.struct<!spv.array<2 x vector<2xi32>, stride=8>>, Workgroup>
+//  CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME:   !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16>>, Workgroup>
+//      CHECK: spv.func @two_allocs_vector()
+//      CHECK: spv.Return
+
+
 // -----
 
 module attributes {

diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index b98a20a56c6a..5ea44c18c618 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -510,6 +510,51 @@ func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
 
 // -----
 
+// Vector types
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_vector
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<2xf32>, stride=8> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16> [0]>, Uniform>
+func @memref_vector(
+    %arg0: memref<4xvector<2xf32>, 0>,
+    %arg1: memref<4xvector<4xf32>, 4>)
+{ return }
+
+// CHECK-LABEL: func @dynamic_dim_memref_vector
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<4xi32>, stride=16> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<2xf32>, stride=8> [0]>, StorageBuffer>
+func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>,
+                         %arg1: memref<?x?xvector<2xf32>>)
+{ return }
+
+} // end module
+
+// -----
+
+// Vector types, check that sizes not available in SPIR-V are not transformed.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_vector_wrong_size
+// CHECK-SAME: memref<4xvector<5xf32>>
+func @memref_vector_wrong_size(
+    %arg0: memref<4xvector<5xf32>, 0>)
+{ return }
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list