[Mlir-commits] [mlir] 2dace04 - [mlir][spirv] Implement gpu::TargetAttrInterface (#69949)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 5 08:11:57 PST 2023


Author: Sang Ik Lee
Date: 2023-11-05T08:11:53-08:00
New Revision: 2dace0452107a43ed030f1156d52282dd6495de2

URL: https://github.com/llvm/llvm-project/commit/2dace0452107a43ed030f1156d52282dd6495de2
DIFF: https://github.com/llvm/llvm-project/commit/2dace0452107a43ed030f1156d52282dd6495de2.diff

LOG: [mlir][spirv] Implement gpu::TargetAttrInterface (#69949)

This commit implements gpu::TargetAttrInterface for the SPIR-V target
attribute (`#spirv.target_env`). The plan is to use this later to enable the
GPU compilation pipeline for OpenCL kernels.

The changes do not impact Vulkan shaders using mlir-vulkan-runner.
A new GPU dialect transform pass, spirv-attach-target, is implemented for
attaching the attribute from the command line.

The gpu-module-to-binary pass now works with GPU modules that contain a
SPIR-V module with OpenCL kernel functions inside.

Added: 
    mlir/include/mlir/Target/SPIRV/Target.h
    mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
    mlir/lib/Target/SPIRV/Target.cpp
    mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
    mlir/test/Dialect/GPU/spirv-attach-targets.mlir
    mlir/test/Dialect/SPIRV/IR/target.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Target/SPIRV/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 2a891a7d24f809a..42fa46b0a57bdee 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -15,6 +15,7 @@
 
 #include "Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 

diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3de8e18851369df..b22d26d49dbdb0e 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -188,4 +188,48 @@ def GpuROCDLAttachTarget: Pass<"rocdl-attach-target", ""> {
   ];
 }
 
+def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
+  let summary = "Attaches an SPIR-V target attribute to a GPU Module.";
+  let description = [{
+    This pass walks all nested GPU modules and attaches a SPIR-V target to each
+    one whose name matches the `module` regex (every module if it is empty).
+
+    Example:
+    ```
+    // Given the following file: in1.mlir:
+    gpu.module @nvvm_module_1 {...}
+    gpu.module @spirv_module_1 {...}
+    // With
+    // mlir-opt --spirv-attach-target="module=spirv.* ver=v1.0 caps=Kernel" in1.mlir
+    // it will generate,
+    gpu.module @nvvm_module_1 {...}
+    gpu.module @spirv_module_1 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>] {...}
+    ```
+  }];
+  let options = [
+    Option<"moduleMatcher", "module", "std::string",
+           /*default=*/ [{""}],
+           "Regex used to identify the modules to attach the target to.">,
+    Option<"spirvVersion", "ver", "std::string",
+           /*default=*/ "\"v1.0\"",
+           "SPIR-V Version.">,
+    ListOption<"spirvCapabilities", "caps", "std::string",
+           "List of supported SPIR-V Capabilities">,
+    ListOption<"spirvExtensions", "exts", "std::string",
+           "List of supported SPIR-V Extensions">,
+    Option<"clientApi", "client_api", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Client API">,
+    Option<"deviceVendor", "vendor", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Device Vendor">,
+    Option<"deviceType", "device_type", "std::string",
+           /*default=*/ "\"Unknown\"",
+           "Device Type">,
+    Option<"deviceId", "device_id", "uint32_t",
+           /*default=*/ "mlir::spirv::TargetEnvAttr::kUnknownDeviceID",
+           "Device ID">,
+  ];
+}
+
 #endif // MLIR_DIALECT_GPU_PASSES

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 1d304610a03a8dc..3b914dc4cc82f11 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -17,6 +17,12 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 
+namespace mlir::spirv {
+// Forward declaration so the name is usable ahead of the TableGen'erated
+// attribute definitions included below (presumably they refer to it).
+class VerCapExtAttr;
+} // namespace mlir::spirv
+
 // Pull in TableGen'erated SPIR-V attribute definitions for target and ABI.
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h.inc"

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index b4eb891f37c1f75..19a62cadaa2e04f 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -91,6 +91,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Target/LLVM/NVVM/Target.h"
 #include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
 
 namespace mlir {
 
@@ -175,6 +176,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   vector::registerSubsetOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+  spirv::registerSPIRVTargetInterfaceExternalModels(registry);
 }
 
 /// Append all the MLIR dialects to the registry contained in the given context.

diff  --git a/mlir/include/mlir/Target/SPIRV/Target.h b/mlir/include/mlir/Target/SPIRV/Target.h
new file mode 100644
index 000000000000000..fea6471ced7c2fc
--- /dev/null
+++ b/mlir/include/mlir/Target/SPIRV/Target.h
@@ -0,0 +1,30 @@
+//===- Target.h - MLIR SPIR-V target registration ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for attaching the SPIR-V target interface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_SPIRV_TARGET_H
+#define MLIR_TARGET_SPIRV_TARGET_H
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+} // namespace mlir
+
+namespace mlir::spirv {
+/// Adds to `registry` an extension that attaches the SPIR-V implementation of
+/// `gpu::TargetAttrInterface` to the `#spirv.target_env` attribute.
+void registerSPIRVTargetInterfaceExternalModels(DialectRegistry &registry);
+/// Same as the `DialectRegistry` overload, but appends the extension to the
+/// dialect registry of `context`.
+void registerSPIRVTargetInterfaceExternalModels(MLIRContext &context);
+} // namespace mlir::spirv
+
+#endif // MLIR_TARGET_SPIRV_TARGET_H

diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 1601413c49f1fc2..09a3cd06788bc5e 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
   Transforms/ShuffleRewriter.cpp
+  Transforms/SPIRVAttachTarget.cpp
   Transforms/ROCDLAttachTarget.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -95,6 +96,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   MLIRPass
   MLIRSCFDialect
   MLIRSideEffectInterfaces
+  MLIRSPIRVTarget
   MLIRSupport
   MLIRROCDLTarget
   MLIRTransformUtils

diff  --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index 2bf89f8c57903e5..e81992ca8e9a331 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -53,6 +54,7 @@ void GpuModuleToBinaryPass::getDependentDialects(
 #if MLIR_ROCM_CONVERSIONS_ENABLED == 1
   registry.insert<ROCDL::ROCDLDialect>();
 #endif
+  registry.insert<spirv::SPIRVDialect>();
 }
 
 void GpuModuleToBinaryPass::runOnOperation() {

diff  --git a/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
new file mode 100644
index 000000000000000..eece62b9c6cb9c8
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SPIRVAttachTarget.cpp
@@ -0,0 +1,129 @@
+//===- SPIRVAttachTarget.cpp - Attach an SPIR-V target --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `GPUSPIRVAttachTarget` pass, attaching
+// `#spirv.target_env` attributes to GPU modules.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Target/SPIRV/Target.h"
+#include "llvm/Support/Regex.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUSPIRVATTACHTARGET
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+namespace {
+/// Attaches a `#spirv.target_env` built from the pass options to every GPU
+/// module whose name matches the `module` regex.
+struct SPIRVAttachTarget
+    : public impl::GpuSPIRVAttachTargetBase<SPIRVAttachTarget> {
+  using Base::Base;
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<spirv::SPIRVDialect>();
+  }
+};
+} // namespace
+
+/// Builds the target environment from the pass options and appends it to the
+/// target list of every matching `gpu.module` that does not already carry it.
+/// Invalid option values are reported as errors and fail the pass instead of
+/// being silently ignored.
+void SPIRVAttachTarget::runOnOperation() {
+  MLIRContext *context = &getContext();
+  Operation *op = getOperation();
+  OpBuilder builder(context);
+
+  auto versionSymbol = symbolizeVersion(spirvVersion);
+  if (!versionSymbol) {
+    op->emitError() << "invalid SPIR-V version: '" << spirvVersion << "'";
+    return signalPassFailure();
+  }
+  auto apiSymbol = symbolizeClientAPI(clientApi);
+  if (!apiSymbol) {
+    op->emitError() << "invalid client API: '" << clientApi << "'";
+    return signalPassFailure();
+  }
+  auto vendorSymbol = symbolizeVendor(deviceVendor);
+  if (!vendorSymbol) {
+    op->emitError() << "invalid device vendor: '" << deviceVendor << "'";
+    return signalPassFailure();
+  }
+  auto deviceTypeSymbol = symbolizeDeviceType(deviceType);
+  if (!deviceTypeSymbol) {
+    op->emitError() << "invalid device type: '" << deviceType << "'";
+    return signalPassFailure();
+  }
+
+  // Unknown capability/extension names are errors: silently dropping them
+  // would produce a target that differs from what was requested.
+  SmallVector<Capability, 4> capabilities;
+  for (const std::string &cap : spirvCapabilities) {
+    auto capSymbol = symbolizeCapability(cap);
+    if (!capSymbol) {
+      op->emitError() << "invalid SPIR-V capability: '" << cap << "'";
+      return signalPassFailure();
+    }
+    capabilities.push_back(*capSymbol);
+  }
+  SmallVector<Extension, 8> extensions;
+  for (const std::string &ext : spirvExtensions) {
+    auto extSymbol = symbolizeExtension(ext);
+    if (!extSymbol) {
+      op->emitError() << "invalid SPIR-V extension: '" << ext << "'";
+      return signalPassFailure();
+    }
+    extensions.push_back(*extSymbol);
+  }
+
+  // llvm::Regex::match() returns false when the pattern failed to compile,
+  // which would silently skip every module, so validate it up front. An empty
+  // pattern means "match every module" and is never used for matching.
+  llvm::Regex matcher(moduleMatcher);
+  std::string regexError;
+  if (!moduleMatcher.empty() && !matcher.isValid(regexError)) {
+    op->emitError() << "invalid module regex '" << moduleMatcher
+                    << "': " << regexError;
+    return signalPassFailure();
+  }
+
+  VerCapExtAttr vce =
+      VerCapExtAttr::get(*versionSymbol, capabilities, extensions, context);
+  auto target = TargetEnvAttr::get(vce, getDefaultResourceLimits(context),
+                                   *apiSymbol, *vendorSymbol,
+                                   *deviceTypeSymbol, deviceId);
+
+  op->walk([&](gpu::GPUModuleOp gpuModule) {
+    // Check if the name of the module matches.
+    if (!moduleMatcher.empty() && !matcher.match(gpuModule.getName()))
+      return;
+    // Keep the existing targets and append this one only if it is not already
+    // present anywhere in the list, so re-running the pass is idempotent.
+    SmallVector<Attribute> targets;
+    if (std::optional<ArrayAttr> attrs = gpuModule.getTargets())
+      targets.append(attrs->getValue().begin(), attrs->getValue().end());
+    if (!llvm::is_contained(targets, target))
+      targets.push_back(target);
+    gpuModule.setTargetsAttr(builder.getArrayAttr(targets));
+  });
+}

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..2de849dc4465e37 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "SPIRVParsingUtils.h"
 
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -133,6 +134,7 @@ void SPIRVDialect::initialize() {
 
   // Allow unknown operations because SPIR-V is extensible.
   allowUnknownOperations();
+  declarePromisedInterface<TargetEnvAttr, gpu::TargetAttrInterface>();
 }
 
 std::string SPIRVDialect::getAttributeName(Decoration decoration) {

diff  --git a/mlir/lib/Target/SPIRV/CMakeLists.txt b/mlir/lib/Target/SPIRV/CMakeLists.txt
index 97bb64a74c41c70..f7a7d6e9378dc6d 100644
--- a/mlir/lib/Target/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Target/SPIRV/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(Serialization)
 set(LLVM_OPTIONAL_SOURCES
   SPIRVBinaryUtils.cpp
   TranslateRegistration.cpp
+  Target.cpp
   )
 
 add_mlir_translation_library(MLIRSPIRVBinaryUtils
@@ -26,3 +27,18 @@ add_mlir_translation_library(MLIRSPIRVTranslateRegistration
   MLIRSupport
   MLIRTranslateLib
   )
+
+# Target.cpp uses gpu.module, #gpu.object and gpu::TargetAttrInterface, so the
+# library must link against the GPU dialect.
+add_mlir_dialect_library(MLIRSPIRVTarget
+  Target.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRGPUDialect
+  MLIRSPIRVDialect
+  MLIRSPIRVSerialization
+  MLIRSPIRVDeserialization
+  MLIRSupport
+  MLIRTranslateLib
+  )

diff  --git a/mlir/lib/Target/SPIRV/Target.cpp b/mlir/lib/Target/SPIRV/Target.cpp
new file mode 100644
index 000000000000000..1b6bfbac38b6193
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Target.cpp
@@ -0,0 +1,131 @@
+//===- Target.cpp - MLIR SPIR-V target compilation --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines SPIR-V target related functions including registration
+// calls for the `#spirv.target_env` compilation attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/SPIRV/Target.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "mlir/Target/SPIRV/Serialization.h"
+
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetSelect.h"
+
+#include <cstdlib>
+#include <cstring>
+
+using namespace mlir;
+using namespace mlir::spirv;
+
+namespace {
+/// SPIR-V implementation of `gpu::TargetAttrInterface`, attached to
+/// `#spirv.target_env` as an external (fallback) model.
+class SPIRVTargetAttrImpl
+    : public gpu::TargetAttrInterface::FallbackModel<SPIRVTargetAttrImpl> {
+public:
+  /// Serializes the `spirv.module` nested in the `gpu.module` `module` into a
+  /// SPIR-V binary blob; returns std::nullopt on failure.
+  std::optional<SmallVector<char, 0>>
+  serializeToObject(Attribute attribute, Operation *module,
+                    const gpu::TargetOptions &options) const;
+
+  /// Wraps a serialized blob into a `#gpu.object` attribute.
+  Attribute createObject(Attribute attribute,
+                         const SmallVector<char, 0> &object,
+                         const gpu::TargetOptions &options) const;
+};
+} // namespace
+
+// Attach the SPIR-V target interface model to `#spirv.target_env` through a
+// dialect extension that runs for the SPIR-V dialect.
+void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, spirv::SPIRVDialect *dialect) {
+    spirv::TargetEnvAttr::attachInterface<SPIRVTargetAttrImpl>(*ctx);
+  });
+}
+
+void mlir::spirv::registerSPIRVTargetInterfaceExternalModels(
+    MLIRContext &context) {
+  DialectRegistry registry;
+  registerSPIRVTargetInterfaceExternalModels(registry);
+  context.appendDialectRegistry(registry);
+}
+
+// Serializes the single `spirv.module` nested in the `gpu.module` with the
+// existing SPIR-V serializer. On success the `spirv.module` is erased from the
+// `gpu.module` and its SPIR-V words are returned as raw bytes.
+std::optional<SmallVector<char, 0>> SPIRVTargetAttrImpl::serializeToObject(
+    Attribute attribute, Operation *module,
+    const gpu::TargetOptions &options) const {
+  if (!module)
+    return std::nullopt;
+  auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
+  if (!gpuMod) {
+    module->emitError("expected to be a gpu.module op");
+    return std::nullopt;
+  }
+  auto spvMods = gpuMod.getOps<spirv::ModuleOp>();
+  if (spvMods.empty()) {
+    gpuMod.emitError("expected a spirv.module op nested in the gpu.module");
+    return std::nullopt;
+  }
+  // Only one module can become the binary; serializing just the first one
+  // would silently drop the kernels of the others.
+  if (!llvm::hasSingleElement(spvMods)) {
+    gpuMod.emitError("expected exactly one spirv.module op in the gpu.module");
+    return std::nullopt;
+  }
+
+  spirv::ModuleOp spvMod = *spvMods.begin();
+  // Serialize the spirv.module op to SPIR-V blob.
+  SmallVector<uint32_t, 0> spvBinary;
+  if (mlir::failed(spirv::serialize(spvMod, spvBinary))) {
+    spvMod.emitError() << "failed to serialize SPIR-V module";
+    return std::nullopt;
+  }
+
+  // Copy the 32-bit SPIR-V words into a byte buffer (host byte order).
+  SmallVector<char, 0> spvData(spvBinary.size() * sizeof(uint32_t), 0);
+  std::memcpy(spvData.data(), spvBinary.data(), spvData.size());
+
+  // The module's contents now live in `spvData`; remove the op itself.
+  spvMod.erase();
+  return spvData;
+}
+
+// Wraps the serialized SPIR-V blob into a `#gpu.object` for `gpu.binary`,
+// tagged with `attribute` and the compilation target requested in `options`,
+// with no extra properties.
+Attribute
+SPIRVTargetAttrImpl::createObject(Attribute attribute,
+                                  const SmallVector<char, 0> &object,
+                                  const gpu::TargetOptions &options) const {
+  gpu::CompilationTarget format = options.getCompilationTarget();
+  DictionaryAttr objectProps;
+  Builder builder(attribute.getContext());
+  return builder.getAttr<gpu::ObjectAttr>(
+      attribute, format,
+      builder.getStringAttr(StringRef(object.data(), object.size())),
+      objectProps);
+}

diff  --git a/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir b/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
new file mode 100644
index 000000000000000..d62f43927980174
--- /dev/null
+++ b/mlir/test/Dialect/GPU/module-to-binary-spirv.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --gpu-module-to-binary | FileCheck %s
+
+module attributes {gpu.container_module} {
+  // CHECK-LABEL:gpu.binary @kernel_module
+  // CHECK:[#gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []>, #spirv.resource_limits<>>, "{{.*}}">]
+  gpu.module @kernel_module [#spirv.target_env<#spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []>, #spirv.resource_limits<>>] {
+    spirv.module @__spv__kernel_module Physical64 OpenCL requires #spirv.vce<v1.0, [Int64, Int16, Kernel, Addresses], []> attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_AMD_shader_ballot]>, api=OpenCL, #spirv.resource_limits<>>} {
+      spirv.func @test_kernel(%arg0: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>, %arg1: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>, %arg2: !spirv.ptr<!spirv.array<200 x i16>, CrossWorkgroup>) "None" attributes {workgroup_attributions = 0 : i64} {
+        spirv.Return
+      }
+    }
+  }
+}

diff  --git a/mlir/test/Dialect/GPU/spirv-attach-targets.mlir b/mlir/test/Dialect/GPU/spirv-attach-targets.mlir
new file mode 100644
index 000000000000000..2ab748834e49fac
--- /dev/null
+++ b/mlir/test/Dialect/GPU/spirv-attach-targets.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s --spirv-attach-target='module=spirv.* ver=v1.0 caps=Kernel' | FileCheck %s
+// RUN: mlir-opt %s --spirv-attach-target='module=spirv_warm.* ver=v1.0 caps=Kernel' | FileCheck %s --check-prefix=CHECK_WARM
+
+module attributes {gpu.container_module} {
+//      CHECK: @spirv_hot_module [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>]
+// CHECK_WARM: @spirv_hot_module {
+gpu.module @spirv_hot_module {
+}
+//      CHECK: @spirv_warm_module [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>]
+// CHECK_WARM: @spirv_warm_module [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>]
+gpu.module @spirv_warm_module {
+}
+//      CHECK: @spirv_cold_module [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>]
+// CHECK_WARM: @spirv_cold_module {
+gpu.module @spirv_cold_module {
+}
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/target.mlir b/mlir/test/Dialect/SPIRV/IR/target.mlir
new file mode 100644
index 000000000000000..6c60fe79f20bb95
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/target.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: @module_1 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>]
+gpu.module @module_1 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>] {
+}
+
+// CHECK: @module_2 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, api=OpenCL, #spirv.resource_limits<>>]
+gpu.module @module_2 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, api=OpenCL, #spirv.resource_limits<>>] {
+}
+
+// CHECK: @module_3 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, api=OpenCL, Intel:IntegratedGPU, #spirv.resource_limits<>>]
+gpu.module @module_3 [#spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, api=OpenCL, Intel:IntegratedGPU, #spirv.resource_limits<>>] {
+}
+


        


More information about the Mlir-commits mailing list