[Mlir-commits] [mlir] [MLIR] SPIRV Target Attribute (PR #69949)
Sang Ik Lee
llvmlistbot at llvm.org
Thu Oct 26 11:34:53 PDT 2023
https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/69949
>From 1a4319cff8d95d5a6a6598f94162be28e56d68a8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Oct 2023 17:23:54 +0000
Subject: [PATCH 1/3] [MLIR] SPIRV Target Attribute
Create SPIRV Target Attribute to enable GPU compilation pipeline.
The Target Attribute is modeled after the existing spirv.target_env.
The plan is to use this new attribute to enable the GPU compilation pipeline
for OpenCL kernels only.
The changes do not impact Vulkan shaders using mlir-vulkan-runner.
A new GPU Dialect transform pass, spirv-attach-target, is implemented for
attaching the attribute from the CLI.
The gpu-module-to-binary pass now works with a GPU module that has a SPIRV module
with OpenCL kernel functions inside.
---
.../mlir/Dialect/GPU/Transforms/Passes.h | 1 +
.../mlir/Dialect/GPU/Transforms/Passes.td | 42 +++++++
.../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h | 6 +
.../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td | 31 +++++
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/include/mlir/Target/SPIRV/Target.h | 30 +++++
mlir/lib/Dialect/GPU/CMakeLists.txt | 2 +
.../Dialect/GPU/Transforms/ModuleToBinary.cpp | 2 +
.../GPU/Transforms/SPIRVAttachTarget.cpp | 94 ++++++++++++++
mlir/lib/Target/SPIRV/CMakeLists.txt | 13 ++
mlir/lib/Target/SPIRV/Target.cpp | 115 ++++++++++++++++++
.../Dialect/GPU/module-to-binary-spirv.mlir | 13 ++
.../Dialect/GPU/spirv-attach-targets.mlir | 7 ++
mlir/test/Dialect/SPIRV/IR/target.mlir | 14 +++
14 files changed, 372 insertions(+)
create mode 100644 mlir/include/mlir/Target/SPIRV/Target.h
create mode 100644 mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
create mode 100644 mlir/lib/Target/SPIRV/Target.cpp
create mode 100644 mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
create mode 100644 mlir/test/Dialect/GPU/spirv-attach-targets.mlir
create mode 100644 mlir/test/Dialect/SPIRV/IR/target.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 2a891a7d24f809a..42fa46b0a57bdee 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -15,6 +15,7 @@
#include "Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Pass/Pass.h"
#include <optional>
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3de8e18851369df..44e3e5b6226bfeb 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -188,4 +188,46 @@ def GpuROCDLAttachTarget: Pass<"rocdl-attach-target", ""> {
];
}
+// Pass attaching a `#spirv.target` attribute, built from the options below, to
+// the GPU modules whose name matches the `module` regex.
+def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
+  let summary = "Attaches an SPIRV target attribute to a GPU Module.";
+  let description = [{
+    This pass searches for all GPU Modules in the immediate regions and attaches
+    an SPIRV target if the module matches the name specified by the `module` argument.
+
+    Example:
+    ```
+    // File: in1.mlir:
+    gpu.module @nvvm_module_1 {...}
+    gpu.module @spirv_module_1 {...}
+    // mlir-opt --spirv-attach-target="module=spirv.* ver=v1.0 caps=Kernel" in1.mlir
+    gpu.module @nvvm_module_1 {...}
+    gpu.module @spirv_module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>] {...}
+    ```
+  }];
+  let options = [
+    Option<"moduleMatcher", "module", "std::string",
+           /*default=*/ [{""}],
+           "Regex used to identify the modules to attach the target to.">,
+    // The `ver` option carries the SPIR-V version (e.g. `v1.0`), not an
+    // addressing model.
+    Option<"spirvVersion", "ver", "std::string",
+           /*default=*/ "\"v1.0\"",
+           "SPIRV Version.">,
+    ListOption<"spirvCapabilities", "caps", "std::string",
+           "List of required SPIRV Capabilities">,
+    ListOption<"spirvExtensions", "exts", "std::string",
+           "List of required SPIRV Extensions">,
+    Option<"clientApi", "client_api", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Client API">,
+    Option<"deviceVendor", "vendor", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Device Vendor">,
+    Option<"deviceType", "device_type", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Device Type">,
+    Option<"deviceId", "device_id", "uint32_t",
+           /*default=*/ "mlir::spirv::TargetEnvAttr::kUnknownDeviceID",
+           "Device ID">,
+  ];
+}
+
#endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 1d304610a03a8dc..3b914dc4cc82f11 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -17,6 +17,12 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
+namespace mlir {
+namespace spirv {
+class VerCapExtAttr;
+}
+} // namespace mlir
+
// Pull in TableGen'erated SPIR-V attribute definitions for target and ABI.
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h.inc"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index f2c1ee5cfd56eab..e026f9dbfc27e30 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -166,4 +166,35 @@ def SPIRV_ResourceLimitsAttr : SPIRV_Attr<"ResourceLimits", "resource_limits"> {
let assemblyFormat = "`<` struct(params) `>`";
}
+//===----------------------------------------------------------------------===//
+// SPIRV target attribute.
+//===----------------------------------------------------------------------===//
+
+// `#spirv.target` attribute: bundles the version/capability/extension triple
+// and resource limits with optional client API, vendor, device type and device
+// id information. The last four parameters have default values (`Unknown` /
+// `kUnknownDeviceID`), so they may be omitted in the textual form.
+def SPIRV_TargetAttr : SPIRV_Attr<"SPIRVTarget", "target"> {
+ let description = [{
+ GPU target attribute for controlling compilation of SPIRV targets.
+ }];
+ let parameters = (ins
+ // Required parameters.
+ "mlir::spirv::VerCapExtAttr":$vce,
+ "mlir::spirv::ResourceLimitsAttr":$resource_limits,
+ // Optional parameters, each falling back to its "unknown" value.
+ DefaultValuedParameter<
+ "mlir::spirv::ClientAPI",
+ "mlir::spirv::ClientAPI::Unknown"
+ >:$client_api,
+ DefaultValuedParameter<
+ "mlir::spirv::Vendor",
+ "mlir::spirv::Vendor::Unknown"
+ >:$vendor_id,
+ DefaultValuedParameter<
+ "mlir::spirv::DeviceType",
+ "mlir::spirv::DeviceType::Unknown"
+ >:$device_type,
+ DefaultValuedParameter<
+ "uint32_t",
+ "mlir::spirv::TargetEnvAttr::kUnknownDeviceID"
+ >:$device_id
+ );
+ // Printed/parsed as `<name = value, ...>` over the parameters above.
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_TARGET_AND_ABI
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..919d6586f70a520 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -90,6 +90,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
namespace mlir {
@@ -173,6 +174,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
vector::registerBufferizableOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+ spirv::registerSPIRVTargetInterfaceExternalModels(registry);
}
/// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/include/mlir/Target/SPIRV/Target.h b/mlir/include/mlir/Target/SPIRV/Target.h
new file mode 100644
index 000000000000000..52154afbac66e57
--- /dev/null
+++ b/mlir/include/mlir/Target/SPIRV/Target.h
@@ -0,0 +1,30 @@
+//===- Target.h - MLIR SPIRV target registration ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for attaching the SPIRV target interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_SPIRV_TARGET_H
+#define MLIR_TARGET_SPIRV_TARGET_H
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+namespace spirv {
+/// Registers the `TargetAttrInterface` for the `#spirv.target` attribute in the
+/// given registry.
+void registerSPIRVTargetInterfaceExternalModels(DialectRegistry &registry);
+
+/// Registers the `TargetAttrInterface` for the `#spirv.target` attribute in the
+/// registry associated with the given context.
+void registerSPIRVTargetInterfaceExternalModels(MLIRContext &context);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_TARGET_SPIRV_TARGET_H
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 324d5c136672270..123767075e006b7 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/SerializeToCubin.cpp
Transforms/SerializeToHsaco.cpp
Transforms/ShuffleRewriter.cpp
+ Transforms/SPIRVAttachTarget.cpp
Transforms/ROCDLAttachTarget.cpp
ADDITIONAL_HEADER_DIRS
@@ -95,6 +96,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
MLIRPass
MLIRSCFDialect
MLIRSideEffectInterfaces
+ MLIRSPIRVTarget
MLIRSupport
MLIRROCDLTarget
MLIRTransformUtils
diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index 2bf89f8c57903e5..e81992ca8e9a331 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -53,6 +54,7 @@ void GpuModuleToBinaryPass::getDependentDialects(
#if MLIR_ROCM_CONVERSIONS_ENABLED == 1
registry.insert<ROCDL::ROCDLDialect>();
#endif
+ registry.insert<spirv::SPIRVDialect>();
}
void GpuModuleToBinaryPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
new file mode 100644
index 000000000000000..7366c0eefd05a9d
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
@@ -0,0 +1,94 @@
+//===- SPIRVAttachTarget.cpp - Attach an SPIRV target ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `GpuSPIRVAttachTarget` pass, attaching
+// `#spirv.target` attributes to GPU modules.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Target/SPIRV/Target.h"
+#include "llvm/Support/Regex.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+namespace {
+/// Pass implementation for `spirv-attach-target`; the option members used by
+/// `runOnOperation` come from the TableGen-generated base class.
+struct SPIRVAttachTarget
+    : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
+  using Base::Base;
+
+  void runOnOperation() override;
+
+  /// Declares the SPIR-V dialect as a dependency, since the pass creates
+  /// `#spirv.target` attributes.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<spirv::SPIRVDialect>();
+  }
+};
+} // namespace
+
+/// Builds a `#spirv.target` attribute from the pass options and appends it to
+/// the target array of every `gpu.module` in the immediate regions of the
+/// operation whose name matches the `module` regex (all modules when the regex
+/// is empty).
+///
+/// The pass fails when the version, client API, vendor or device type option
+/// cannot be symbolized. Capability and extension names that cannot be
+/// symbolized are skipped.
+void SPIRVAttachTarget::runOnOperation() {
+  OpBuilder builder(&getContext());
+
+  // Symbolize every scalar option exactly once; an unknown name is an error.
+  auto versionSymbol = symbolizeVersion(spirvVersion);
+  if (!versionSymbol)
+    return signalPassFailure();
+  auto clientApiSymbol = symbolizeClientAPI(clientApi);
+  if (!clientApiSymbol)
+    return signalPassFailure();
+  auto vendorSymbol = symbolizeVendor(deviceVendor);
+  if (!vendorSymbol)
+    return signalPassFailure();
+  auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
+  if (!deviceTypeSymbol)
+    return signalPassFailure();
+
+  SmallVector<Capability, 4> capabilities;
+  for (const auto &capName : spirvCapabilities) {
+    if (auto capability = symbolizeCapability(capName))
+      capabilities.push_back(*capability);
+  }
+
+  // Extension names must be validated with `symbolizeExtension`; checking them
+  // against the capability table would drop every valid extension.
+  SmallVector<Extension, 8> extensions;
+  for (const auto &extName : spirvExtensions) {
+    if (auto extension = symbolizeExtension(extName))
+      extensions.push_back(*extension);
+  }
+
+  VerCapExtAttr vce = VerCapExtAttr::get(
+      *versionSymbol, ArrayRef<Capability>(capabilities),
+      ArrayRef<Extension>(extensions), &getContext());
+  auto target = builder.getAttr<SPIRVTargetAttr>(
+      vce, getDefaultResourceLimits(&getContext()), *clientApiSymbol,
+      *vendorSymbol, *deviceTypeSymbol, deviceId);
+
+  llvm::Regex matcher(moduleMatcher);
+  for (Region &region : getOperation()->getRegions())
+    for (Block &block : region.getBlocks())
+      for (auto module : block.getOps<gpu::GPUModuleOp>()) {
+        // Check if the name of the module matches.
+        if (!moduleMatcher.empty() && !matcher.match(module.getName()))
+          continue;
+        // Create the target array, starting from any existing targets.
+        SmallVector<Attribute> targets;
+        if (std::optional<ArrayAttr> attrs = module.getTargets())
+          targets.append(attrs->getValue().begin(), attrs->getValue().end());
+        targets.push_back(target);
+        // Drop adjacent duplicate targets (`std::unique` only collapses
+        // consecutive equal elements).
+        targets.erase(std::unique(targets.begin(), targets.end()),
+                      targets.end());
+        // Update the target attribute array.
+        module.setTargetsAttr(builder.getArrayAttr(targets));
+      }
+}
diff --git a/mlir/lib/Target/SPIRV/CMakeLists.txt b/mlir/lib/Target/SPIRV/CMakeLists.txt
index 97bb64a74c41c70..f7a7d6e9378dc6d 100644
--- a/mlir/lib/Target/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Target/SPIRV/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(Serialization)
set(LLVM_OPTIONAL_SOURCES
SPIRVBinaryUtils.cpp
TranslateRegistration.cpp
+ Target.cpp
)
add_mlir_translation_library(MLIRSPIRVBinaryUtils
@@ -26,3 +27,15 @@ add_mlir_translation_library(MLIRSPIRVTranslateRegistration
MLIRSupport
MLIRTranslateLib
)
+
+add_mlir_dialect_library(MLIRSPIRVTarget
+ Target.cpp
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSPIRVDialect
+ MLIRSPIRVSerialization
+ MLIRSPIRVDeserialization
+ MLIRSupport
+ MLIRTranslateLib
+ )
diff --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp
new file mode 100644
index 000000000000000..e704c696fc5da4a
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Target.cpp
@@ -0,0 +1,115 @@
+//===- Target.cpp - MLIR SPIRV target compilation ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This files defines SPIRV target related functions including registration
+// calls for the `#spirv.target` compilation attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/SPIRV/Target.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/SPIRV/Serialization.h"
+
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetSelect.h"
+
+#include <cstdlib>
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+namespace {
+// Implementation of the `TargetAttrInterface` model.
+// Declared as a fallback (external) model so it can be attached to
+// `SPIRVTargetAttr` from outside the attribute's own definition.
+class SPIRVTargetAttrImpl
+ : public gpu::TargetAttrInterface::FallbackModel<SPIRVTargetAttrImpl> {
+public:
+ // Serializes the SPIR-V contents of `module` to bytes; yields `std::nullopt`
+ // on failure.
+ std::optional<SmallVector<char, 0>>
+ serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const;
+
+ // Wraps an already serialized `object` into an attribute.
+ Attribute createObject(Attribute attribute,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const;
+};
+} // namespace
+
+// Adds a dialect extension to `registry` that attaches the
+// `TargetAttrInterface` external model to `#spirv.target`. The callback takes
+// a `SPIRVDialect` pointer, tying the extension to that dialect.
+void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, spirv::SPIRVDialect *dialect) {
+    spirv::SPIRVTargetAttr::attachInterface<SPIRVTargetAttrImpl>(*ctx);
+  });
+}
+
+// Context flavour of the registration: collects the extension in a temporary
+// registry and appends that registry to `context`.
+void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
+    MLIRContext &context) {
+  DialectRegistry targetModels;
+  registerSPIRVTargetInterfaceExternalModels(targetModels);
+  context.appendDialectRegistry(targetModels);
+}
+
+// Serializes the first `spirv.module` nested in the given GPU module with the
+// existing `spirv::serialize` serializer and returns the binary as raw bytes.
+// Returns `std::nullopt` when `module` is null, is not a `gpu.module`, or the
+// serialization fails. On success the serialized `spirv.module` op is erased.
+// `attribute` and `options` are not used here.
+std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
+ Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const {
+ // The explicit null check below still guards builds where assertions are
+ // compiled out.
+ assert(module && "The module must be non null.");
+ if (!module)
+ return std::nullopt;
+ if (!mlir::isa<gpu::GPUModuleOp>(module)) {
+ module->emitError("Module must be a GPU module.");
+ return std::nullopt;
+ }
+ auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
+ auto spvMods = gpuMod.getOps<spirv::ModuleOp>();
+ auto spvMod = *spvMods.begin();
+ llvm::SmallVector<uint32_t, 0> spvBinary;
+
+ spvBinary.clear();
+ // Serialize the SPIR-V module into a buffer of 32-bit words.
+ if (mlir::failed(spirv::serialize(spvMod, spvBinary))) {
+ spvMod.emitError() << "Failed to serialize SPIR-V module";
+ return std::nullopt;
+ }
+
+ SmallVector<char, 0> spvData;
+ const char *data = reinterpret_cast<const char *>(spvBinary.data());
+ for (uint32_t i = 0; i < spvBinary.size() * sizeof(uint32_t); i++) {
+ spvData.push_back(*(data + i));
+ }
+
+ spvMod.erase();
+ return spvData;
+}
+
+// Prepares the attribute for `gpu.binary`: wraps the serialized kernel
+// `object` in a `gpu::ObjectAttr` that records the target `attribute` and the
+// compilation target format taken from `options`.
+Attribute
+SPIRVTargetAttrImpl::createObject(Attribute attribute,
+                                  const SmallVector<char, 0> &object,
+                                  const gpu::TargetOptions &options) const {
+  // Keep the type check without an unused local variable.
+  assert(isa<SPIRVTargetAttr>(attribute) &&
+         "expected a #spirv.target attribute");
+  gpu::CompilationTarget format = options.getCompilationTarget();
+  // Left default-constructed: no object properties are attached.
+  DictionaryAttr objectProps;
+  Builder builder(attribute.getContext());
+  return builder.getAttr<gpu::ObjectAttr>(
+      attribute, format,
+      builder.getStringAttr(StringRef(object.data(), object.size())),
+      objectProps);
+}
diff --git a/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir b/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
new file mode 100644
index 000000000000000..e6cc1ad4edc622b
--- /dev/null
+++ b/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --gpu-module-to-binary | FileCheck %s
+
+module attributes {gpu.container_module} {
+ // CHECK-LABEL:gpu.binary @kernel_module
+ // CHECK:[#gpu.object<#spirv.target<vce = #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []>, resource_limits = <>>, "{{.*}}">]
+ gpu.module @kernel_module [#spirv.target<vce = #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []>, resource_limits = <>>] {
+ spirv.module @__spv__kernel_module Physical64 OpenCL requires #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []> attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_AMD_shader_ballot]>, api=OpenCL, #spirv.resource_limits<>>} {
+ spirv.func @test_kernel(%arg0: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>, %arg1: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>, %arg2: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>) "None" attributes {workgroup_attributions = 0 : i64} {
+ spirv.Return
+ }
+ }
+ }
+}
diff --git a/mlir/test/Dialect/GPU/spirv-attach-targets.mlir b/mlir/test/Dialect/GPU/spirv-attach-targets.mlir
new file mode 100644
index 000000000000000..f9ab3c88d4a69da
--- /dev/null
+++ b/mlir/test/Dialect/GPU/spirv-attach-targets.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s --spirv-attach-target='module=spirv.* ver=v1.0 caps=Kernel' | FileCheck %s
+
+module attributes {gpu.container_module} {
+// CHECK: @spirv_module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>]
+gpu.module @spirv_module_1 {
+}
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/target.mlir b/mlir/test/Dialect/SPIRV/IR/target.mlir
new file mode 100644
index 000000000000000..89b9344f0498126
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: @module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>]
+gpu.module @module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>] {
+}
+
+// CHECK: @module_2 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL>]
+gpu.module @module_2 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL>] {
+}
+
+// CHECK: @module_3 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL, vendor_id = Intel, device_type = IntegratedGPU>]
+gpu.module @module_3 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL, vendor_id = Intel, device_type = IntegratedGPU>] {
+}
+
>From 54adf4099c12eee7be39faae95309e530b16e8de Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Oct 2023 17:40:44 +0000
Subject: [PATCH 2/3] Register spirv target attribute.
---
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..ad42f77fc536cd3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -14,6 +14,7 @@
#include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -133,6 +134,7 @@ void SPIRVDialect::initialize() {
// Allow unknown operations because SPIR-V is extensible.
allowUnknownOperations();
+ declarePromisedInterface<SPIRVTargetAttr, gpu::TargetAttrInterface>();
}
std::string SPIRVDialect::getAttributeName(Decoration decoration) {
>From 08b1d2d538922413081ad049119bf38d7a5d7b68 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 26 Oct 2023 18:34:33 +0000
Subject: [PATCH 3/3] Address reviewer comments.
---
mlir/lib/Target/SPIRV/Target.cpp | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp
index e704c696fc5da4a..a7367479ec74b8e 100644
--- a/mlir/lib/Target/SPIRV/Target.cpp
+++ b/mlir/lib/Target/SPIRV/Target.cpp
@@ -32,6 +32,7 @@
#include "llvm/Support/TargetSelect.h"
#include <cstdlib>
+#include <cstring>
using namespace mlir;
using namespace mlir::spirv;
@@ -79,6 +80,10 @@ std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
}
auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
auto spvMods = gpuMod.getOps<spirv::ModuleOp>();
+ // Empty spirv::ModuleOp
+ if (spvMods.empty()) {
+ return std::nullopt;
+ }
auto spvMod = *spvMods.begin();
llvm::SmallVector<uint32_t, 0> spvBinary;
@@ -89,11 +94,8 @@ std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
return std::nullopt;
}
- SmallVector<char, 0> spvData;
- const char *data = reinterpret_cast<const char *>(spvBinary.data());
- for (uint32_t i = 0; i < spvBinary.size() * sizeof(uint32_t); i++) {
- spvData.push_back(*(data + i));
- }
+ SmallVector<char, 0> spvData(spvBinary.size() * sizeof(uint32_t), 0);
+ std::memcpy(spvData.data(), spvBinary.data(), spvData.size());
spvMod.erase();
return spvData;
More information about the Mlir-commits
mailing list