[Mlir-commits] [mlir] 59156ba - [mlir][spirv] Add support for converting memref of vector to SPIR-V
Thomas Raoux
llvmlistbot at llvm.org
Thu Jul 30 15:06:07 PDT 2020
Author: Thomas Raoux
Date: 2020-07-30T15:05:40-07:00
New Revision: 59156bad03ffe37558b95fca62b5df4394de280c
URL: https://github.com/llvm/llvm-project/commit/59156bad03ffe37558b95fca62b5df4394de280c
DIFF: https://github.com/llvm/llvm-project/commit/59156bad03ffe37558b95fca62b5df4394de280c.diff
LOG: [mlir][spirv] Add support for converting memref of vector to SPIR-V
This allows declaring buffers and allocs of vectors so that we can support vector
load/store.
Differential Revision: https://reviews.llvm.org/D84982
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/test/Conversion/StandardToSPIRV/alloc.mlir
mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index e59830fcef89..543b23acabeb 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -217,11 +217,15 @@ CHECK_UNSIGNED_OP(spirv::UModOp)
/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
static bool isAllocationSupported(MemRefType t) {
// Currently only support workgroup local memory allocations with static
- // shape and int or float element type.
- return t.hasStaticShape() &&
- SPIRVTypeConverter::getMemorySpaceForStorageClass(
- spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
- t.getElementType().isIntOrFloat();
+ // shape whose element type is int, float, or a vector of int/float.
+ if (!(t.hasStaticShape() &&
+ SPIRVTypeConverter::getMemorySpaceForStorageClass(
+ spirv::StorageClass::Workgroup) == t.getMemorySpace()))
+ return false;
+ Type elementType = t.getElementType();
+ if (auto vecType = elementType.dyn_cast<VectorType>())
+ elementType = vecType.getElementType(); // Validate the vector's scalar type.
+ return elementType.isIntOrFloat();
}
/// Returns the scope to use for atomic operations use for emulating store
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index d31f9c28362a..3d7535f9110e 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -170,7 +170,14 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
return llvm::None;
}
return bitWidth / 8;
- } else if (auto memRefType = t.dyn_cast<MemRefType>()) {
+ }
+ if (auto vecType = t.dyn_cast<VectorType>()) {
+ auto elementSize = getTypeNumBytes(vecType.getElementType());
+ if (!elementSize)
+ return llvm::None;
+ return vecType.getNumElements() * *elementSize;
+ }
+ if (auto memRefType = t.dyn_cast<MemRefType>()) {
// TODO: Layout should also be controlled by the ABI attributes. For now
// using the layout from MemRef.
int64_t offset;
@@ -343,26 +350,31 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return llvm::None;
}
- auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
- if (!scalarType) {
- LLVM_DEBUG(llvm::dbgs()
- << type << " illegal: cannot convert non-scalar element type\n");
+ Optional<Type> arrayElemType;
+ Type elementType = type.getElementType();
+ if (auto vecType = elementType.dyn_cast<VectorType>()) {
+ arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
+ } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
+ arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
+ } else {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << type
+ << " unhandled: can only convert scalar or vector element type\n");
return llvm::None;
}
-
- auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
if (!arrayElemType)
return llvm::None;
- Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
- if (!scalarSize) {
+ Optional<int64_t> elementSize = getTypeNumBytes(elementType);
+ if (!elementSize) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: cannot deduce element size\n");
return llvm::None;
}
if (!type.hasStaticShape()) {
- auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
+ auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
// Wrap in a struct to satisfy Vulkan interface requirements.
auto structType = spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
@@ -375,7 +387,7 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
return llvm::None;
}
- auto arrayElemCount = *memrefSize / *scalarSize;
+ auto arrayElemCount = *memrefSize / *elementSize;
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
if (!arrayElemSize) {
diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
index 3cbeda1cafb0..fe4c9d125a26 100644
--- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -75,6 +75,30 @@ module attributes {
// CHECK: spv.func @two_allocs()
// CHECK: spv.Return
+// -----
+
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+ }
+{
+ func @two_allocs_vector() {
+ %0 = alloc() : memref<4xvector<4xf32>, 3> // expected: !spv.array<4 x vector<4xf32>, stride=16>
+ %1 = alloc() : memref<2xvector<2xi32>, 3> // expected: !spv.array<2 x vector<2xi32>, stride=8>
+ return
+ }
+}
+
+// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<2 x vector<2xi32>, stride=8>>, Workgroup>
+// CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16>>, Workgroup>
+// CHECK: spv.func @two_allocs_vector()
+// CHECK: spv.Return
+
+
// -----
module attributes {
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index b98a20a56c6a..5ea44c18c618 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -510,6 +510,51 @@ func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
// -----
+// Vector types
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_vector
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<2xf32>, stride=8> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<4 x vector<4xf32>, stride=16> [0]>, Uniform>
+func @memref_vector(
+ %arg0: memref<4xvector<2xf32>, 0>, // memory space 0 -> StorageBuffer
+ %arg1: memref<4xvector<4xf32>, 4>) // memory space 4 -> Uniform
+{ return }
+
+// CHECK-LABEL: func @dynamic_dim_memref_vector
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<4xi32>, stride=16> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<vector<2xf32>, stride=8> [0]>, StorageBuffer>
+func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>, // dynamic shapes lower to !spv.rtarray
+ %arg1: memref<?x?xvector<2xf32>>)
+{ return }
+
+} // end module
+
+// -----
+
+// Vector types: check that memrefs whose vector size is not available in SPIR-V are not converted.
+module attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_vector_wrong_size
+// CHECK-SAME: memref<4xvector<5xf32>>
+func @memref_vector_wrong_size(
+ %arg0: memref<4xvector<5xf32>, 0>) // 5-element vector: expected to stay unconverted
+{ return }
+
+} // end module
+
+// -----
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list