[clang-tools-extra] [MLIR] SPIRV Target Attribute (PR #69949)

Sang Ik Lee via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 26 14:35:11 PDT 2023


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/69949

>From 1a4319cff8d95d5a6a6598f94162be28e56d68a8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Oct 2023 17:23:54 +0000
Subject: [PATCH 1/3] [MLIR] SPIRV Target Attribute

Create a SPIRV Target Attribute to enable the GPU compilation pipeline.
The Target Attribute is modeled after the existing spirv.target_env.
The plan is to use this new attribute to enable the GPU compilation pipeline
for OpenCL kernels only.
The changes do not impact Vulkan shaders using mlir-vulkan-runner.
A new GPU dialect transform pass, spirv-attach-target, is implemented for
attaching the attribute from the command line.
The gpu-module-to-binary pass now works with a GPU module that contains a SPIRV
module with OpenCL kernel functions inside.
---
 .../mlir/Dialect/GPU/Transforms/Passes.h      |   1 +
 .../mlir/Dialect/GPU/Transforms/Passes.td     |  42 +++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h   |   6 +
 .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td  |  31 +++++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/include/mlir/Target/SPIRV/Target.h       |  30 +++++
 mlir/lib/Dialect/GPU/CMakeLists.txt           |   2 +
 .../Dialect/GPU/Transforms/ModuleToBinary.cpp |   2 +
 .../GPU/Transforms/SPIRVAttachTarget.cpp      |  94 ++++++++++++++
 mlir/lib/Target/SPIRV/CMakeLists.txt          |  13 ++
 mlir/lib/Target/SPIRV/Target.cpp              | 115 ++++++++++++++++++
 .../Dialect/GPU/module-to-binary-spirv.mlir   |  13 ++
 .../Dialect/GPU/spirv-attach-targets.mlir     |   7 ++
 mlir/test/Dialect/SPIRV/IR/target.mlir        |  14 +++
 14 files changed, 372 insertions(+)
 create mode 100644 mlir/include/mlir/Target/SPIRV/Target.h
 create mode 100644 mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
 create mode 100644 mlir/lib/Target/SPIRV/Target.cpp
 create mode 100644 mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
 create mode 100644 mlir/test/Dialect/GPU/spirv-attach-targets.mlir
 create mode 100644 mlir/test/Dialect/SPIRV/IR/target.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 2a891a7d24f809a..42fa46b0a57bdee 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -15,6 +15,7 @@
 
 #include "Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3de8e18851369df..44e3e5b6226bfeb 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -188,4 +188,46 @@ def GpuROCDLAttachTarget: Pass<"rocdl-attach-target", ""> {
   ];
 }
 
+def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
+  let summary = "Attaches an SPIRV target attribute to a GPU Module.";
+  let description = [{
+    This pass searches for all GPU Modules in the immediate regions and attaches
+    an SPIRV target if the module matches the name specified by the `module` argument.
+
+    Example:
+    ```
+    // File: in1.mlir:
+    gpu.module @nvvm_module_1 {...}
+    gpu.module @spirv_module_1 {...}
+    // mlir-opt --spirv-attach-target="module=spirv.* ver=v1.0 caps=Kernel" in1.mlir
+    gpu.module @nvvm_module_1 {...}
+    gpu.module @spirv_module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>] {...}
+    ```
+  }];
+  let options = [
+    Option<"moduleMatcher", "module", "std::string",
+           /*default=*/ [{""}],
+           "Regex used to identify the modules to attach the target to.">,
+    Option<"spirvVersion", "ver", "std::string",
+           /*default=*/ "\"v1.0\"",
+           "SPIRV Version.">,
+    ListOption<"spirvCapabilities", "caps", "std::string",
+           "List of required SPIRV Capabilities">,
+    ListOption<"spirvExtensions", "exts", "std::string",
+           "List of required SPIRV Extensions">,
+    Option<"clientApi", "client_api", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Client API">,
+    Option<"deviceVendor", "vendor", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Device Vendor">,
+    Option<"deviceType", "device_type", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Device Type">,
+    Option<"deviceId", "device_id", "uint32_t",
+           /*default=*/ "mlir::spirv::TargetEnvAttr::kUnknownDeviceID",
+           "Device ID">,
+  ];
+}
+
 #endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 1d304610a03a8dc..3b914dc4cc82f11 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -17,6 +17,12 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 
+namespace mlir {
+namespace spirv {
+class VerCapExtAttr; // Parameter type of the generated SPIRVTargetAttr.
+} // namespace spirv
+} // namespace mlir
+
 // Pull in TableGen'erated SPIR-V attribute definitions for target and ABI.
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h.inc"
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index f2c1ee5cfd56eab..e026f9dbfc27e30 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -166,4 +166,43 @@ def SPIRV_ResourceLimitsAttr : SPIRV_Attr<"ResourceLimits", "resource_limits"> {
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// SPIRV target attribute.
+//===----------------------------------------------------------------------===//
+
+def SPIRV_TargetAttr : SPIRV_Attr<"SPIRVTarget", "target"> {
+  let description = [{
+    GPU target attribute for controlling compilation of SPIRV targets.
+    `vce` and `resource_limits` are required. `client_api`, `vendor_id` and
+    `device_type` default to `Unknown`, and `device_id` defaults to
+    `TargetEnvAttr::kUnknownDeviceID`, when omitted.
+
+    Example:
+    ```
+    #spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>
+    ```
+  }];
+  let parameters = (ins
+    "mlir::spirv::VerCapExtAttr":$vce,
+    "mlir::spirv::ResourceLimitsAttr":$resource_limits,
+    DefaultValuedParameter<
+      "mlir::spirv::ClientAPI",
+      "mlir::spirv::ClientAPI::Unknown"
+    >:$client_api,
+    DefaultValuedParameter<
+      "mlir::spirv::Vendor",
+      "mlir::spirv::Vendor::Unknown"
+    >:$vendor_id,
+    DefaultValuedParameter<
+      "mlir::spirv::DeviceType",
+      "mlir::spirv::DeviceType::Unknown"
+    >:$device_type,
+    DefaultValuedParameter<
+      "uint32_t",
+      "mlir::spirv::TargetEnvAttr::kUnknownDeviceID"
+    >:$device_id
+  );
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_TARGET_AND_ABI
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..919d6586f70a520 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -90,6 +90,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Target/LLVM/NVVM/Target.h"
 #include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
 
 namespace mlir {
 
@@ -173,6 +174,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+  spirv::registerSPIRVTargetInterfaceExternalModels(registry);
 }
 
 /// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/include/mlir/Target/SPIRV/Target.h b/mlir/include/mlir/Target/SPIRV/Target.h
new file mode 100644
index 000000000000000..52154afbac66e57
--- /dev/null
+++ b/mlir/include/mlir/Target/SPIRV/Target.h
@@ -0,0 +1,30 @@
+//===- Target.h - MLIR SPIRV target registration ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for attaching the SPIRV target interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_SPIRV_TARGET_H
+#define MLIR_TARGET_SPIRV_TARGET_H
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+namespace spirv {
+/// Registers the `gpu::TargetAttrInterface` external model for the
+/// `#spirv.target` attribute in the given registry.
+void registerSPIRVTargetInterfaceExternalModels(DialectRegistry &registry);
+
+/// Same as above, but appends the registration to the dialect registry of the
+/// given context.
+void registerSPIRVTargetInterfaceExternalModels(MLIRContext &context);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_TARGET_SPIRV_TARGET_H
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 324d5c136672270..123767075e006b7 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
   Transforms/ShuffleRewriter.cpp
+  Transforms/SPIRVAttachTarget.cpp
   Transforms/ROCDLAttachTarget.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -95,6 +96,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRPass
   MLIRSCFDialect
   MLIRSideEffectInterfaces
+  MLIRSPIRVTarget
   MLIRSupport
   MLIRROCDLTarget
   MLIRTransformUtils
diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index 2bf89f8c57903e5..e81992ca8e9a331 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -53,6 +54,7 @@ void GpuModuleToBinaryPass::getDependentDialects(
 #if MLIR_ROCM_CONVERSIONS_ENABLED == 1
   registry.insert<ROCDL::ROCDLDialect>();
 #endif
+  registry.insert<spirv::SPIRVDialect>();
 }
 
 void GpuModuleToBinaryPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
new file mode 100644
index 000000000000000..7366c0eefd05a9d
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
@@ -0,0 +1,94 @@
+//===- SPIRVAttachTarget.cpp - Attach an SPIRV target ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `GpuSPIRVAttachTarget` pass, attaching
+// `#spirv.target` attributes to GPU modules.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Target/SPIRV/Target.h"
+#include "llvm/Support/Regex.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+namespace {
+/// Attaches a `#spirv.target` attribute, built from the pass options, to the
+/// GPU modules found directly in the regions of the root operation.
+struct SPIRVAttachTarget
+    : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
+  using Base::Base;
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<spirv::SPIRVDialect>();
+  }
+};
+} // namespace
+
+void SPIRVAttachTarget::runOnOperation() {
+  OpBuilder builder(&getContext());
+  // Parse each enum-valued option exactly once.
+  std::optional<Version> version = symbolizeVersion(spirvVersion);
+  std::optional<ClientAPI> api = symbolizeClientAPI(clientApi);
+  std::optional<Vendor> vendor = symbolizeVendor(deviceVendor);
+  std::optional<DeviceType> devType = symbolizeDeviceType(deviceType);
+  if (!version || !api || !vendor || !devType) {
+    getOperation()->emitError("invalid SPIR-V version, client API, vendor or "
+                              "device type option");
+    return signalPassFailure();
+  }
+
+  // Unrecognized capability/extension names are skipped (best effort).
+  SmallVector<Capability, 4> capabilities;
+  for (const std::string &cap : spirvCapabilities)
+    if (std::optional<Capability> symbol = symbolizeCapability(cap))
+      capabilities.push_back(*symbol);
+  SmallVector<Extension, 8> extensions;
+  for (const std::string &ext : spirvExtensions)
+    if (std::optional<Extension> symbol = symbolizeExtension(ext))
+      extensions.push_back(*symbol);
+
+  VerCapExtAttr vce =
+      VerCapExtAttr::get(*version, capabilities, extensions, &getContext());
+  auto target = builder.getAttr<SPIRVTargetAttr>(
+      vce, getDefaultResourceLimits(&getContext()), *api, *vendor, *devType,
+      deviceId);
+
+  // Attach to every GPU module, or only those matching `module` if it is set.
+  llvm::Regex matcher(moduleMatcher);
+  for (Region &region : getOperation()->getRegions())
+    for (Block &block : region.getBlocks())
+      for (auto module : block.getOps<gpu::GPUModuleOp>()) {
+        // Check if the name of the module matches.
+        if (!moduleMatcher.empty() && !matcher.match(module.getName()))
+          continue;
+        // Append the target unless an identical one is already attached.
+        SmallVector<Attribute> targets;
+        if (std::optional<ArrayAttr> attrs = module.getTargets())
+          targets.append(attrs->getValue().begin(), attrs->getValue().end());
+        if (!llvm::is_contained(targets, target))
+          targets.push_back(target);
+        // Update the target attribute array.
+        module.setTargetsAttr(builder.getArrayAttr(targets));
+      }
+}
diff --git a/mlir/lib/Target/SPIRV/CMakeLists.txt b/mlir/lib/Target/SPIRV/CMakeLists.txt
index 97bb64a74c41c70..f7a7d6e9378dc6d 100644
--- a/mlir/lib/Target/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Target/SPIRV/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(Serialization)
 set(LLVM_OPTIONAL_SOURCES
   SPIRVBinaryUtils.cpp
   TranslateRegistration.cpp
+  Target.cpp
   )
 
 add_mlir_translation_library(MLIRSPIRVBinaryUtils
@@ -26,3 +27,15 @@ add_mlir_translation_library(MLIRSPIRVTranslateRegistration
   MLIRSupport
   MLIRTranslateLib
   )
+
+add_mlir_dialect_library(MLIRSPIRVTarget
+  Target.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRGPUDialect
+  MLIRIR
+  MLIRSPIRVDialect
+  MLIRSPIRVSerialization
+  MLIRSupport
+  MLIRTranslateLib
+  )
diff --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp
new file mode 100644
index 000000000000000..e704c696fc5da4a
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Target.cpp
@@ -0,0 +1,115 @@
+//===- Target.cpp - MLIR SPIRV target compilation ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines SPIRV target related functions including registration
+// calls for the `#spirv.target` compilation attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/SPIRV/Target.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/SPIRV/Serialization.h"
+
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetSelect.h"
+
+#include <cstdlib>
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+namespace {
+/// External `gpu::TargetAttrInterface` model attached to `#spirv.target`.
+class SPIRVTargetAttrImpl
+    : public gpu::TargetAttrInterface::FallbackModel<SPIRVTargetAttrImpl> {
+public:
+  std::optional<SmallVector<char, 0>>
+  serializeToObject(Attribute attribute, Operation *module,
+                    const gpu::TargetOptions &options) const;
+
+  Attribute createObject(Attribute attribute,
+                         const SmallVector<char, 0> &object,
+                         const gpu::TargetOptions &options) const;
+};
+} // namespace
+
+// Attach the SPIRV target interface model via a SPIRV dialect extension.
+void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, spirv::SPIRVDialect *dialect) {
+    spirv::SPIRVTargetAttr::attachInterface<SPIRVTargetAttrImpl>(*ctx);
+  });
+}
+
+void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
+    MLIRContext &context) {
+  DialectRegistry registry; // Scratch registry merged into `context` below.
+  registerSPIRVTargetInterfaceExternalModels(registry);
+  context.appendDialectRegistry(registry);
+}
+
+// Serialize the first spirv.module in the GPU module; it is erased on success.
+std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
+    Attribute attribute, Operation *module,
+    const gpu::TargetOptions &options) const {
+  assert(module && "The module must be non null.");
+  if (!module) // Guards release builds, where the assert is compiled out.
+    return std::nullopt;
+  if (!mlir::isa<gpu::GPUModuleOp>(module)) {
+    module->emitError("Module must be a GPU module.");
+    return std::nullopt;
+  }
+  auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
+  auto spvMods = gpuMod.getOps<spirv::ModuleOp>();
+  auto spvMod = *spvMods.begin();
+  llvm::SmallVector<uint32_t, 0> spvBinary;
+
+  spvBinary.clear(); // NOTE(review): no-op, the vector was just constructed.
+  // Serialize the spirv.module into a sequence of 32-bit SPIR-V words.
+  if (mlir::failed(spirv::serialize(spvMod, spvBinary))) {
+    spvMod.emitError() << "Failed to serialize SPIR-V module";
+    return std::nullopt;
+  }
+
+  SmallVector<char, 0> spvData;
+  const char *data = reinterpret_cast<const char *>(spvBinary.data());
+  for (uint32_t i = 0; i < spvBinary.size() * sizeof(uint32_t); i++) {
+    spvData.push_back(*(data + i));
+  }
+
+  spvMod.erase();
+  return spvData;
+}
+
+// Wrap the serialized SPIR-V binary in a `#gpu.object` for `gpu.binary`.
+Attribute
+SPIRVTargetAttrImpl::createObject(Attribute attribute,
+                                  const SmallVector<char, 0> &object,
+                                  const gpu::TargetOptions &options) const {
+  assert(isa<SPIRVTargetAttr>(attribute) && "expected a #spirv.target attr");
+  gpu::CompilationTarget format = options.getCompilationTarget();
+  DictionaryAttr objectProps; // No additional object properties.
+  Builder builder(attribute.getContext());
+  return builder.getAttr<gpu::ObjectAttr>(
+      attribute, format,
+      builder.getStringAttr(StringRef(object.data(), object.size())),
+      objectProps);
+}
diff --git a/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir b/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
new file mode 100644
index 000000000000000..e6cc1ad4edc622b
--- /dev/null
+++ b/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --gpu-module-to-binary | FileCheck %s
+
+module attributes {gpu.container_module} {
+  // CHECK-LABEL:gpu.binary @kernel_module
+  // CHECK:[#gpu.object<#spirv.target<vce = #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []>, resource_limits = <>>, "{{.*}}">]
+  gpu.module @kernel_module [#spirv.target<vce = #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []>, resource_limits = <>>] {
+    spirv.module @__spv__kernel_module Physical64 OpenCL requires #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []> attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_AMD_shader_ballot]>, api=OpenCL, #spirv.resource_limits<>>} {
+      spirv.func @test_kernel(%arg0: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>, %arg1: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>, %arg2: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>) "None" attributes {workgroup_attributions = 0 : i64} {
+        spirv.Return
+      }
+    }
+  }
+}
diff --git a/mlir/test/Dialect/GPU/spirv-attach-targets.mlir b/mlir/test/Dialect/GPU/spirv-attach-targets.mlir
new file mode 100644
index 000000000000000..f9ab3c88d4a69da
--- /dev/null
+++ b/mlir/test/Dialect/GPU/spirv-attach-targets.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s --spirv-attach-target='module=spirv.* ver=v1.0 caps=Kernel' | FileCheck %s
+
+module attributes {gpu.container_module} {
+// CHECK: @spirv_module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>]
+gpu.module @spirv_module_1 {
+}
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/target.mlir b/mlir/test/Dialect/SPIRV/IR/target.mlir
new file mode 100644
index 000000000000000..89b9344f0498126
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: @module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>]
+gpu.module @module_1 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>>] {
+}
+
+// CHECK: @module_2 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL>]
+gpu.module @module_2 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL>] {
+}
+
+// CHECK: @module_3 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL, vendor_id = Intel, device_type = IntegratedGPU>]
+gpu.module @module_3 [#spirv.target<vce = #spirv.vce<v1.0, [Kernel], []>, resource_limits = <>, client_api = OpenCL, vendor_id = Intel, device_type = IntegratedGPU>] {
+}
+

>From 54adf4099c12eee7be39faae95309e530b16e8de Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 23 Oct 2023 17:40:44 +0000
Subject: [PATCH 2/3] Register spirv target attribute.

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..ad42f77fc536cd3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "SPIRVParsingUtils.h"
 
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -133,6 +134,7 @@ void SPIRVDialect::initialize() {
 
   // Allow unknown operations because SPIR-V is extensible.
   allowUnknownOperations();
+  declarePromisedInterface<SPIRVTargetAttr, gpu::TargetAttrInterface>();
 }
 
 std::string SPIRVDialect::getAttributeName(Decoration decoration) {

>From 08b1d2d538922413081ad049119bf38d7a5d7b68 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 26 Oct 2023 18:34:33 +0000
Subject: [PATCH 3/3] Address reviewer comments.

---
 mlir/lib/Target/SPIRV/Target.cpp | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp
index e704c696fc5da4a..a7367479ec74b8e 100644
--- a/mlir/lib/Target/SPIRV/Target.cpp
+++ b/mlir/lib/Target/SPIRV/Target.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/TargetSelect.h"
 
 #include <cstdlib>
+#include <cstring>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -79,6 +80,10 @@ std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
   }
   auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
   auto spvMods = gpuMod.getOps<spirv::ModuleOp>();
+  // Empty spirv::ModuleOp
+  if (spvMods.empty()) {
+    return std::nullopt;
+  }
   auto spvMod = *spvMods.begin();
   llvm::SmallVector<uint32_t, 0> spvBinary;
 
@@ -89,11 +94,8 @@ std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
     return std::nullopt;
   }
 
-  SmallVector<char, 0> spvData;
-  const char *data = reinterpret_cast<const char *>(spvBinary.data());
-  for (uint32_t i = 0; i < spvBinary.size() * sizeof(uint32_t); i++) {
-    spvData.push_back(*(data + i));
-  }
+  SmallVector<char, 0> spvData(spvBinary.size() * sizeof(uint32_t), 0);
+  std::memcpy(spvData.data(), spvBinary.data(), spvData.size());
 
   spvMod.erase();
   return spvData;



More information about the cfe-commits mailing list