[Mlir-commits] [mlir] e672f51 - [mlir][spirv] Add a field for client API in target environment

Lei Zhang llvmlistbot at llvm.org
Fri Nov 25 13:38:20 PST 2022


Author: Lei Zhang
Date: 2022-11-25T21:38:00Z
New Revision: e672f5126fcfca650534ee5fd81425df36c76eb6

URL: https://github.com/llvm/llvm-project/commit/e672f5126fcfca650534ee5fd81425df36c76eb6
DIFF: https://github.com/llvm/llvm-project/commit/e672f5126fcfca650534ee5fd81425df36c76eb6.diff

LOG: [mlir][spirv] Add a field for client API in target environment

SPIR-V can be directly consumed by APIs like Vulkan and OpenCL,
where we can use the capability list to differentiate. It can
also be used as a compilation target to transcompile to shading
languages like WGSL to target WebGPU. We have no way to tell
that with just the capability list, so we cannot perform certain
transformations only applicable to those targets thus far. So
this commit adds a field in the target environment to indicate
the client API for such purposes.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D138732

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
    mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
    mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 3afbaf27dde04..2de4bc225f71f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -138,9 +138,11 @@ class TargetEnvAttr
   using Base::Base;
 
   /// Gets a TargetEnvAttr instance.
-  static TargetEnvAttr get(VerCapExtAttr triple, Vendor vendorID,
-                           DeviceType deviceType, uint32_t deviceId,
-                           ResourceLimitsAttr limits);
+  static TargetEnvAttr get(VerCapExtAttr triple, ResourceLimitsAttr limits,
+                           ClientAPI clientAPI = ClientAPI::Unknown,
+                           Vendor vendorID = Vendor::Unknown,
+                           DeviceType deviceType = DeviceType::Unknown,
+                           uint32_t deviceId = kUnknownDeviceID);
 
   /// Returns the attribute kind's name (without the 'spirv.' prefix).
   static StringRef getKindName();
@@ -161,6 +163,9 @@ class TargetEnvAttr
   /// Returns the target capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
+  /// Returns the client API.
+  ClientAPI getClientAPI() const;
+
   /// Returns the vendor ID.
   Vendor getVendorID() const;
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 400947a4043e3..536cc136d0c48 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -267,7 +267,7 @@ def SPIRV_DT_IntegratedGPU : I32EnumAttrCase<"IntegratedGPU", 2>;
 // An accelerator other than GPU or CPU
 def SPIRV_DT_Other         : I32EnumAttrCase<"Other", 3>;
 // Information missing.
-def SPIRV_DT_Unknown       : I32EnumAttrCase<"Unknown", 4>;
+def SPIRV_DT_Unknown       : I32EnumAttrCase<"Unknown", 0xffffffff>;
 
 def SPIRV_DeviceTypeAttr : SPIRV_I32EnumAttr<
   "DeviceType", "valid SPIR-V device types", "device_type", [
@@ -283,7 +283,7 @@ def SPIRV_V_Intel       : I32EnumAttrCase<"Intel", 4>;
 def SPIRV_V_NVIDIA      : I32EnumAttrCase<"NVIDIA", 5>;
 def SPIRV_V_Qualcomm    : I32EnumAttrCase<"Qualcomm", 6>;
 def SPIRV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>;
-def SPIRV_V_Unknown     : I32EnumAttrCase<"Unknown", 0xff>;
+def SPIRV_V_Unknown     : I32EnumAttrCase<"Unknown", 0xffffffff>;
 
 def SPIRV_VendorAttr : SPIRV_I32EnumAttr<
   "Vendor", "recognized SPIR-V vendor strings", "vendor", [
@@ -292,6 +292,18 @@ def SPIRV_VendorAttr : SPIRV_I32EnumAttr<
     SPIRV_V_Unknown
   ]>;
 
+def SPIRV_CA_Metal   : I32EnumAttrCase<"Metal", 0>;
+def SPIRV_CA_OpenCL  : I32EnumAttrCase<"OpenCL", 1>;
+def SPIRV_CA_Vulkan  : I32EnumAttrCase<"Vulkan", 2>;
+def SPIRV_CA_WebGPU  : I32EnumAttrCase<"WebGPU", 3>;
+def SPIRV_CA_Unknown : I32EnumAttrCase<"Unknown", 0xffffffff>;
+
+def SPIRV_ClientAPIAttr : SPIRV_I32EnumAttr<
+  "ClientAPI", "recognized SPIR-V client APIs", "client_api", [
+    SPIRV_CA_Metal, SPIRV_CA_OpenCL, SPIRV_CA_Vulkan, SPIRV_CA_WebGPU,
+    SPIRV_CA_Unknown
+  ]>;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V extension definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index ea5bd3158ea0b..d03ffd0d08698 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -82,17 +82,18 @@ struct VerCapExtAttributeStorage : public AttributeStorage {
 };
 
 struct TargetEnvAttributeStorage : public AttributeStorage {
-  using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
+  using KeyTy =
+      std::tuple<Attribute, ClientAPI, Vendor, DeviceType, uint32_t, Attribute>;
 
-  TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
-                            DeviceType deviceType, uint32_t deviceID,
-                            Attribute limits)
-      : triple(triple), limits(limits), vendorID(vendorID),
-        deviceType(deviceType), deviceID(deviceID) {}
+  TargetEnvAttributeStorage(Attribute triple, ClientAPI clientAPI,
+                            Vendor vendorID, DeviceType deviceType,
+                            uint32_t deviceID, Attribute limits)
+      : triple(triple), limits(limits), clientAPI(clientAPI),
+        vendorID(vendorID), deviceType(deviceType), deviceID(deviceID) {}
 
   bool operator==(const KeyTy &key) const {
-    return key ==
-           std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
+    return key == std::make_tuple(triple, clientAPI, vendorID, deviceType,
+                                  deviceID, limits);
   }
 
   static TargetEnvAttributeStorage *
@@ -100,11 +101,12 @@ struct TargetEnvAttributeStorage : public AttributeStorage {
     return new (allocator.allocate<TargetEnvAttributeStorage>())
         TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
                                   std::get<2>(key), std::get<3>(key),
-                                  std::get<4>(key));
+                                  std::get<4>(key), std::get<5>(key));
   }
 
   Attribute triple;
   Attribute limits;
+  ClientAPI clientAPI;
   Vendor vendorID;
   DeviceType deviceType;
   uint32_t deviceID;
@@ -282,14 +284,13 @@ spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 // TargetEnvAttr
 //===----------------------------------------------------------------------===//
 
-spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
-                                               Vendor vendorID,
-                                               DeviceType deviceType,
-                                               uint32_t deviceID,
-                                               ResourceLimitsAttr limits) {
+spirv::TargetEnvAttr spirv::TargetEnvAttr::get(
+    spirv::VerCapExtAttr triple, ResourceLimitsAttr limits, ClientAPI clientAPI,
+    Vendor vendorID, DeviceType deviceType, uint32_t deviceID) {
   assert(triple && limits && "expected valid triple and limits");
   MLIRContext *context = triple.getContext();
-  return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
+  return Base::get(context, triple, clientAPI, vendorID, deviceType, deviceID,
+                   limits);
 }
 
 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
@@ -318,6 +319,10 @@ ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
   return getTripleAttr().getCapabilitiesAttr();
 }
 
+spirv::ClientAPI spirv::TargetEnvAttr::getClientAPI() const {
+  return getImpl()->clientAPI;
+}
+
 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
   return getImpl()->vendorID;
 }
@@ -523,6 +528,22 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
   if (parser.parseAttribute(tripleAttr) || parser.parseComma())
     return {};
 
+  auto clientAPI = spirv::ClientAPI::Unknown;
+  if (succeeded(parser.parseOptionalKeyword("api"))) {
+    if (parser.parseEqual())
+      return {};
+    auto loc = parser.getCurrentLocation();
+    StringRef apiStr;
+    if (parser.parseKeyword(&apiStr))
+      return {};
+    if (auto apiSymbol = spirv::symbolizeClientAPI(apiStr))
+      clientAPI = *apiSymbol;
+    else
+      parser.emitError(loc, "unknown client API: ") << apiStr;
+    if (parser.parseComma())
+      return {};
+  }
+
   // Parse [vendor[:device-type[:device-id]]]
   Vendor vendorID = Vendor::Unknown;
   DeviceType deviceType = DeviceType::Unknown;
@@ -531,22 +552,20 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
     auto loc = parser.getCurrentLocation();
     StringRef vendorStr;
     if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
-      if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
+      if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr))
         vendorID = *vendorSymbol;
-      } else {
+      else
         parser.emitError(loc, "unknown vendor: ") << vendorStr;
-      }
 
       if (succeeded(parser.parseOptionalColon())) {
         loc = parser.getCurrentLocation();
         StringRef deviceTypeStr;
         if (parser.parseKeyword(&deviceTypeStr))
           return {};
-        if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
+        if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr))
           deviceType = *deviceTypeSymbol;
-        } else {
+        else
           parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
-        }
 
         if (succeeded(parser.parseOptionalColon())) {
           loc = parser.getCurrentLocation();
@@ -563,8 +582,8 @@ static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
   if (parser.parseAttribute(limitsAttr) || parser.parseGreater())
     return {};
 
-  return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
-                                   limitsAttr);
+  return spirv::TargetEnvAttr::get(tripleAttr, limitsAttr, clientAPI, vendorID,
+                                   deviceType, deviceID);
 }
 
 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
@@ -616,6 +635,9 @@ static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
   printer << spirv::TargetEnvAttr::getKindName() << "<#spirv.";
   print(targetEnv.getTripleAttr(), printer);
+  auto clientAPI = targetEnv.getClientAPI();
+  if (clientAPI != spirv::ClientAPI::Unknown)
+    printer << ", api=" << clientAPI;
   spirv::Vendor vendorID = targetEnv.getVendorID();
   spirv::DeviceType deviceType = targetEnv.getDeviceType();
   uint32_t deviceID = targetEnv.getDeviceID();

diff  --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index a36527dfa7c85..bfe95c8ed0b78 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/FunctionInterfaces.h"
@@ -170,10 +171,10 @@ spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
   auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0,
                                           {spirv::Capability::Shader},
                                           ArrayRef<Extension>(), context);
-  return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown,
-                                   spirv::DeviceType::Unknown,
-                                   spirv::TargetEnvAttr::kUnknownDeviceID,
-                                   spirv::getDefaultResourceLimits(context));
+  return spirv::TargetEnvAttr::get(
+      triple, spirv::getDefaultResourceLimits(context),
+      spirv::ClientAPI::Unknown, spirv::Vendor::Unknown,
+      spirv::DeviceType::Unknown, spirv::TargetEnvAttr::kUnknownDeviceID);
 }
 
 spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {

diff  --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 0184b649080eb..ed84746d49ab0 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -118,6 +118,24 @@ func.func @target_env() attributes {
 
 // -----
 
+func.func @target_env_client_api() attributes {
+  // CHECK:      spirv.target_env = #spirv.target_env<
+  // CHECK-SAME:   #spirv.vce<v1.0, [], []>,
+  // CHECK-SAME:   api=Metal,
+  // CHECK-SAME:   #spirv.resource_limits<>>
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Metal, #spirv.resource_limits<>>
+} { return }
+
+// -----
+
+func.func @target_env_client_api() attributes {
+  // CHECK:      spirv.target_env = #spirv.target_env
+  // CHECK-NOT:   api=
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Unknown, #spirv.resource_limits<>>
+} { return }
+
+// -----
+
 func.func @target_env_vendor_id() attributes {
   // CHECK:      spirv.target_env = #spirv.target_env<
   // CHECK-SAME:   #spirv.vce<v1.0, [], []>,
@@ -148,6 +166,17 @@ func.func @target_env_vendor_id_device_type_device_id() attributes {
 
 // -----
 
+func.func @target_env_client_api_vendor_id_device_type_device_id() attributes {
+  // CHECK:      spirv.target_env = #spirv.target_env<
+  // CHECK-SAME:   #spirv.vce<v1.0, [], []>,
+  // CHECK-SAME:   api=Vulkan,
+  // CHECK-SAME:   Qualcomm:IntegratedGPU:100925441,
+  // CHECK-SAME:   #spirv.resource_limits<>>
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, api=Vulkan, Qualcomm:IntegratedGPU:0x6040001, #spirv.resource_limits<>>
+} { return }
+
+// -----
+
 func.func @target_env_extra_fields() attributes {
   // expected-error @+3 {{expected '>'}}
   spirv.target_env = #spirv.target_env<


        


More information about the Mlir-commits mailing list