[Mlir-commits] [mlir] 12ce9fd - [mlir][spirv] Add client-api option to -convert-spirv-to-llvm
Victor Perez
llvmlistbot at llvm.org
Thu Aug 24 01:49:04 PDT 2023
Author: Victor Perez
Date: 2023-08-24T09:48:36+01:00
New Revision: 12ce9fd1248c6321b343601c1a2468ac7e00c9da
URL: https://github.com/llvm/llvm-project/commit/12ce9fd1248c6321b343601c1a2468ac7e00c9da
DIFF: https://github.com/llvm/llvm-project/commit/12ce9fd1248c6321b343601c1a2468ac7e00c9da.diff
LOG: [mlir][spirv] Add client-api option to -convert-spirv-to-llvm
Add an option expressing that `spirv` StorageClasses should be mapped to LLVM
address spaces during the conversion process. This mapping is
client-dependent.
The client API cannot be inferred from the input IR: more than one module
could be present, resulting in more than one VCE triple and possibly different
StorageClass-to-address-space mappings, and this information would not be
available during type conversion.
A specific mapping for the OpenCL client is defined, based on [the
OpenCL Extended Instruction
Set](https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form)
and [this
mapping](https://github.com/llvm/llvm-project/blob/3edd338a6407d9410f6a283c5dc32ba676ac0b8f/clang/lib/Basic/Targets/SPIR.h#L27).
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
Reviewed By: antiagainst, kuhar
Differential Revision: https://reviews.llvm.org/D158627
Added:
mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir
mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3c2962ab86f631..cb06b917b6a9b9 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -954,7 +954,17 @@ def ConvertSPIRVToLLVMPass : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
let options = [
Option<"useOpaquePointers", "use-opaque-pointers", "bool",
/*default=*/"true", "Generate LLVM IR using opaque pointers "
- "instead of typed pointers">
+ "instead of typed pointers">,
+ Option<"clientAPI", "client-api", "::mlir::spirv::ClientAPI",
+ /*default=*/"::mlir::spirv::ClientAPI::Unknown",
+ "Derive StorageClass to address space mapping from the client API",
+ [{::llvm::cl::values(
+ clEnumValN(::mlir::spirv::ClientAPI::Unknown, "Unknown", "Unknown (default)"),
+ clEnumValN(::mlir::spirv::ClientAPI::Metal, "Metal", "Metal"),
+ clEnumValN(::mlir::spirv::ClientAPI::OpenCL, "OpenCL", "OpenCL"),
+ clEnumValN(::mlir::spirv::ClientAPI::Vulkan, "Vulkan", "Vulkan"),
+ clEnumValN(::mlir::spirv::ClientAPI::WebGPU, "WebGPU", "WebGPU")
+ )}]>,
];
}
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
index 74b14bdf2efb13..84935d19670437 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
@@ -15,6 +15,8 @@
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+
namespace mlir {
class LLVMTypeConverter;
class MLIRContext;
@@ -37,11 +39,20 @@ class SPIRVToLLVMConversion : public OpConversionPattern<SPIRVOp> {
void encodeBindAttribute(ModuleOp module);
-/// Populates type conversions with additional SPIR-V types.
-void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
+/// Populates type conversions with additional SPIR-V types. The storage class
+/// of SPIR-V pointers is mapped to an LLVM address space as defined by
+/// `clientAPIForAddressSpaceMapping`; unmapped cases use address space 0.
+void populateSPIRVToLLVMTypeConversion(
+ LLVMTypeConverter &typeConverter,
+ spirv::ClientAPI clientAPIForAddressSpaceMapping =
+ spirv::ClientAPI::Unknown);
-/// Populates the given list with patterns that convert from SPIR-V to LLVM.
-void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter,
- RewritePatternSet &patterns);
+/// Populates the given list with patterns that convert from SPIR-V to LLVM.
+/// `clientAPIForAddressSpaceMapping` selects the address space of converted
+/// global variables; it should match the client API used for type conversion.
+void populateSPIRVToLLVMConversionPatterns(
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+ spirv::ClientAPI clientAPIForAddressSpaceMapping =
+ spirv::ClientAPI::Unknown);
/// Populates the given list with patterns for function conversion from SPIR-V
/// to LLVM.
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
index c845a745e89d49..d085b924ecde00 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
@@ -15,6 +15,8 @@
#include <memory>
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" // spirv::ClientAPI (client-api option)
+
namespace mlir {
class Pass;
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 1d32e6e55f6ae4..60f34f413f587d 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -29,6 +29,12 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+constexpr unsigned defaultAddressSpace = 0;
+
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
@@ -271,12 +277,49 @@ static std::optional<Type> convertArrayType(spirv::ArrayType type,
return LLVM::LLVMArrayType::get(llvmElementType, numElements);
}
+static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
+ // Based on
+ // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form
+ // and clang/lib/Basic/Targets/SPIR.h.
+ switch (storageClass) {
+#define STORAGE_SPACE_MAP(storage, space) \
+ case spirv::StorageClass::storage: \
+ return space;
+ STORAGE_SPACE_MAP(Function, 0)
+ STORAGE_SPACE_MAP(CrossWorkgroup, 1)
+ STORAGE_SPACE_MAP(Input, 1)
+ STORAGE_SPACE_MAP(UniformConstant, 2)
+ STORAGE_SPACE_MAP(Workgroup, 3)
+ STORAGE_SPACE_MAP(Generic, 4)
+ STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5)
+ STORAGE_SPACE_MAP(HostOnlyINTEL, 6)
+#undef STORAGE_SPACE_MAP
+ default:
+ return defaultAddressSpace;
+ }
+}
+
+static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
+ spirv::StorageClass storageClass) {
+ switch (clientAPI) {
+#define CLIENT_MAP(client, storage) \
+ case spirv::ClientAPI::client: \
+ return mapTo##client##AddressSpace(storage);
+ CLIENT_MAP(OpenCL, storageClass)
+#undef CLIENT_MAP
+ default:
+ return defaultAddressSpace;
+ }
+}
+
-/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
-/// modelled at the moment.
+/// Converts SPIR-V pointer type to LLVM pointer. The pointer's storage class is
+/// mapped to an LLVM address space based on `clientAPI` (0 when unmapped).
static Type convertPointerType(spirv::PointerType type,
- LLVMTypeConverter &converter) {
+ LLVMTypeConverter &converter,
+ spirv::ClientAPI clientAPI) {
auto pointeeType = converter.convertType(type.getPointeeType());
- return converter.getPointerType(pointeeType);
+ unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass());
+ return converter.getPointerType(pointeeType, addressSpace);
}
/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
@@ -734,7 +777,11 @@ class ExecutionModePattern
class GlobalVariablePattern
: public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
public:
- using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
+ template <typename... Args>
+ GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
+ : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
+ std::forward<Args>(args)...),
+ clientAPI(clientAPI) {}
LogicalResult
matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
@@ -779,7 +826,7 @@ class GlobalVariablePattern
: LLVM::Linkage::External;
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
- /*alignment=*/0);
+ /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass));
// Attach location attribute if applicable
if (op.getLocationAttr())
@@ -787,6 +834,9 @@ class GlobalVariablePattern
return success();
}
+
+private:
+ spirv::ClientAPI clientAPI;
};
/// Converts SPIR-V cast ops that do not have straightforward LLVM
@@ -1494,12 +1544,13 @@ class VectorShufflePattern
// Pattern population
//===----------------------------------------------------------------------===//
-void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
+void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
+ spirv::ClientAPI clientAPI) {
typeConverter.addConversion([&](spirv::ArrayType type) {
return convertArrayType(type, typeConverter);
});
- typeConverter.addConversion([&](spirv::PointerType type) {
- return convertPointerType(type, typeConverter);
+ typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
+ return convertPointerType(type, typeConverter, clientAPI);
});
typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
return convertRuntimeArrayType(type, typeConverter);
@@ -1510,7 +1561,8 @@ void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
}
void mlir::populateSPIRVToLLVMConversionPatterns(
- LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+ LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+ spirv::ClientAPI clientAPI) {
patterns.add<
// Arithmetic ops
DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
@@ -1605,9 +1657,8 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
NotPattern<spirv::LogicalNotOp>,
// Memory ops
- AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
- LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
- VariablePattern,
+ AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
+ LoadStorePattern<spirv::StoreOp>, VariablePattern,
// Miscellaneous ops
CompositeExtractPattern, CompositeInsertPattern,
@@ -1622,6 +1673,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
// Return ops
ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
+
+ patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
+ typeConverter);
}
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
index 263276ef1b9b2d..40798e9eb9dcba 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -50,16 +51,22 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
RewritePatternSet patterns(context);
- populateSPIRVToLLVMTypeConversion(converter);
+ populateSPIRVToLLVMTypeConversion(converter, clientAPI);
populateSPIRVToLLVMModuleConversionPatterns(converter, patterns);
- populateSPIRVToLLVMConversionPatterns(converter, patterns);
+ populateSPIRVToLLVMConversionPatterns(converter, patterns, clientAPI);
populateSPIRVToLLVMFunctionConversionPatterns(converter, patterns);
ConversionTarget target(*context);
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
+ if (clientAPI != spirv::ClientAPI::OpenCL &&
+ clientAPI != spirv::ClientAPI::Unknown) // No mapping yet: pointers stay in address space 0.
+ getOperation()->emitWarning()
+ << "address space mapping for client '"
+ << spirv::stringifyClientAPI(clientAPI) << "' not implemented";
+
// Set `ModuleOp` as legal for `spirv.module` conversion.
target.addLegalOp<ModuleOp>();
if (failed(applyPartialConversion(module, target, std::move(patterns))))
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir
new file mode 100644
index 00000000000000..626637386cca09
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir
@@ -0,0 +1,5 @@
+// RUN: mlir-opt -convert-spirv-to-llvm='client-api=Metal' -verify-diagnostics %s
+// RUN: mlir-opt -convert-spirv-to-llvm='client-api=Vulkan' -verify-diagnostics %s
+// RUN: mlir-opt -convert-spirv-to-llvm='client-api=WebGPU' -verify-diagnostics %s
+
+module {} // expected-warning-re {{address space mapping for client '{{.*}}' not implemented}}
new file mode 100644
index 00000000000000..989ada93cf36ee
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt -convert-spirv-to-llvm='use-opaque-pointers=1' -verify-diagnostics %s | FileCheck %s --check-prefixes=CHECK-UNKNOWN,CHECK-ALL
+// RUN: mlir-opt -convert-spirv-to-llvm='use-opaque-pointers=1 client-api=OpenCL' -verify-diagnostics %s | FileCheck %s --check-prefixes=CHECK-OPENCL,CHECK-ALL
+
+// CHECK-OPENCL: llvm.func @pointerUniformConstant(!llvm.ptr<2>)
+// CHECK-UNKNOWN: llvm.func @pointerUniformConstant(!llvm.ptr)
+spirv.func @pointerUniformConstant(!spirv.ptr<i1, UniformConstant>) "None"
+
+// CHECK-OPENCL: llvm.mlir.global external constant @varUniformConstant() {addr_space = 2 : i32} : i1
+// CHECK-UNKNOWN: llvm.mlir.global external constant @varUniformConstant() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varUniformConstant : !spirv.ptr<i1, UniformConstant>
+
+// CHECK-OPENCL: llvm.func @pointerInput(!llvm.ptr<1>)
+// CHECK-UNKNOWN: llvm.func @pointerInput(!llvm.ptr)
+spirv.func @pointerInput(!spirv.ptr<i1, Input>) "None"
+
+// CHECK-OPENCL: llvm.mlir.global external constant @varInput() {addr_space = 1 : i32} : i1
+// CHECK-UNKNOWN: llvm.mlir.global external constant @varInput() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varInput : !spirv.ptr<i1, Input>
+
+// CHECK-ALL: llvm.func @pointerUniform(!llvm.ptr)
+spirv.func @pointerUniform(!spirv.ptr<i1, Uniform>) "None"
+
+// CHECK-ALL: llvm.func @pointerOutput(!llvm.ptr)
+spirv.func @pointerOutput(!spirv.ptr<i1, Output>) "None"
+
+// CHECK-ALL: llvm.mlir.global external @varOutput() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varOutput : !spirv.ptr<i1, Output>
+
+// CHECK-OPENCL: llvm.func @pointerWorkgroup(!llvm.ptr<3>)
+// CHECK-UNKNOWN: llvm.func @pointerWorkgroup(!llvm.ptr)
+spirv.func @pointerWorkgroup(!spirv.ptr<i1, Workgroup>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerCrossWorkgroup(!llvm.ptr<1>)
+// CHECK-UNKNOWN: llvm.func @pointerCrossWorkgroup(!llvm.ptr)
+spirv.func @pointerCrossWorkgroup(!spirv.ptr<i1, CrossWorkgroup>) "None"
+
+// CHECK-ALL: llvm.func @pointerPrivate(!llvm.ptr)
+spirv.func @pointerPrivate(!spirv.ptr<i1, Private>) "None"
+
+// CHECK-ALL: llvm.mlir.global private @varPrivate() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varPrivate : !spirv.ptr<i1, Private>
+
+// CHECK-ALL: llvm.func @pointerFunction(!llvm.ptr)
+spirv.func @pointerFunction(!spirv.ptr<i1, Function>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerGeneric(!llvm.ptr<4>)
+// CHECK-UNKNOWN: llvm.func @pointerGeneric(!llvm.ptr)
+spirv.func @pointerGeneric(!spirv.ptr<i1, Generic>) "None"
+
+// CHECK-ALL: llvm.func @pointerPushConstant(!llvm.ptr)
+spirv.func @pointerPushConstant(!spirv.ptr<i1, PushConstant>) "None"
+
+// CHECK-ALL: llvm.func @pointerAtomicCounter(!llvm.ptr)
+spirv.func @pointerAtomicCounter(!spirv.ptr<i1, AtomicCounter>) "None"
+
+// CHECK-ALL: llvm.func @pointerImage(!llvm.ptr)
+spirv.func @pointerImage(!spirv.ptr<i1, Image>) "None"
+
+// CHECK-ALL: llvm.func @pointerStorageBuffer(!llvm.ptr)
+spirv.func @pointerStorageBuffer(!spirv.ptr<i1, StorageBuffer>) "None"
+
+// CHECK-ALL: llvm.mlir.global external @varStorageBuffer() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varStorageBuffer : !spirv.ptr<i1, StorageBuffer>
+
+// CHECK-ALL: llvm.func @pointerCallableDataKHR(!llvm.ptr)
+spirv.func @pointerCallableDataKHR(!spirv.ptr<i1, CallableDataKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerIncomingCallableDataKHR(!llvm.ptr)
+spirv.func @pointerIncomingCallableDataKHR(!spirv.ptr<i1, IncomingCallableDataKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerRayPayloadKHR(!llvm.ptr)
+spirv.func @pointerRayPayloadKHR(!spirv.ptr<i1, RayPayloadKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerHitAttributeKHR(!llvm.ptr)
+spirv.func @pointerHitAttributeKHR(!spirv.ptr<i1, HitAttributeKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerIncomingRayPayloadKHR(!llvm.ptr)
+spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerShaderRecordBufferKHR(!llvm.ptr)
+spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+
+// CHECK-ALL: llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
+spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerDeviceOnlyINTEL(!llvm.ptr<5>)
+// CHECK-UNKNOWN: llvm.func @pointerDeviceOnlyINTEL(!llvm.ptr)
+spirv.func @pointerDeviceOnlyINTEL(!spirv.ptr<i1, DeviceOnlyINTEL>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerHostOnlyINTEL(!llvm.ptr<6>)
+// CHECK-UNKOWN: llvm.func @pointerHostOnlyINTEL(!llvm.ptr)
+spirv.func @pointerHostOnlyINTEL(!spirv.ptr<i1, HostOnlyINTEL>) "None"
More information about the Mlir-commits
mailing list