[Mlir-commits] [mlir] Use combined-check for type related extension and capability requirements (PR #68033)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Mon Oct 2 13:36:19 PDT 2023


https://github.com/mshahneo created https://github.com/llvm/llvm-project/pull/68033

Please review the second commit. The first commit is included here only for completeness; it has been opened as a separate PR (https://github.com/llvm/llvm-project/pull/68031).

>From 43ebcad9cdeeb647d8cc1be99463db77b61258e4 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <Md.Abdullah.Shahneous.Bari at intel.com>
Date: Mon, 2 Oct 2023 11:15:08 -0700
Subject: [PATCH 1/2] [mlir][spirv] Extend capabilities and extensions
 requirements checking.

Allow a way to relax requirements for certain capabilities and
extensions (e.g., `elidedCandidates`).

Also add a combined check for capabilities and extensions in
`checkCapabilityAndExtensionRequirements`.
This function checks capabilities, extensions, and
capability-inferred extension requirements.
---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 67 +++++++++++++++++--
 1 file changed, 61 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..7bcd36da0c21ee6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
 template <typename LabelT>
 static LogicalResult checkExtensionRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+    const ArrayRef<spirv::Extension> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+          return llvm::is_contained(ors, elidedExt);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
 template <typename LabelT>
 static LogicalResult checkCapabilityRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+    const ArrayRef<spirv::Capability> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+          return llvm::is_contained(ors, elidedCap);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
   return success();
 }
 
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Checks capability and extension requirements.
+/// Checks that `capCandidates`, `extCandidates`, and capability
+/// (`capCandidates`) inferred extension requirements can be
+/// satisfied with the given `targetEnv`.
+/// It also provides a way to relax requirements for certain capabilities and
+/// extensions (via `elidedCapCandidates` and `elidedExtCandidates`). This is to
+/// allow passes to relax certain requirements based on an option (e.g.,
+/// relaxing the bitwidth requirement; see `convertScalarType()` and
+/// `convertVectorType()`).
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+    const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+    const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+    const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+  SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+  llvm::append_range(updatedExtCandidates, extCandidates);
+
+  if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+                                         elidedCapCandidates)))
+    return failure();
+  // Add capability-inferred extensions to the list of extension requirements;
+  // only consider the capabilities that are already available in `targetEnv`.
+
+  // WARNING: Some capabilities are part of both the core SPIR-V
+  // specification and an extension (e.g., 'Groups' capability is part of both
+  // core specification and SPV_AMD_shader_ballot extension, hence we should
+  // relax the capability inferred extension for these cases).
+  static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+  ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
+                                                     std::size(multiModalCaps));
+
+  for (auto cap : targetEnv.getAttr().getCapabilities()) {
+    if (llvm::any_of(multiModalCapsArrayRef,
+                     [&cap](spirv::Capability mMCap) { return cap == mMCap; }))
+      continue;
+    std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
+    if (ext)
+      updatedExtCandidates.push_back(*ext);
+  }
+  if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+                                        elidedExtCandidates)))
+    return failure();
+  return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
   switch (storageClass) {
   case spirv::StorageClass::PhysicalStorageBuffer:

>From f5d497994b1221c18f6ca25c2b324d913fb61f24 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <Md.Abdullah.Shahneous.Bari at intel.com>
Date: Mon, 2 Oct 2023 11:16:43 -0700
Subject: [PATCH 2/2] [mlir][spirv] Use combined-check for type related
 extension and capability requirements

Replace the separate extension and capability checking with the combined check
`checkCapabilityAndExtensionRequirements()`. This makes the code flow simpler.
It also adds the extra check for capability-inferred extensions.

Need for the capability-inferred extension check:
If a capability is a requirement, the respective extension that implements
it should also become an extension requirement. There was no support for
that check; as a result, the extension requirement had to be added separately.
This separate requirement addition causes problems when a feature is enabled by
multiple capabilities and one of the capabilities is part of an extension. E.g.,
a vector size of 16 can be enabled by both the "Vector16" and "vectorAnyINTEL"
capabilities; however, only "vectorAnyINTEL" has an extension requirement
("SPV_INTEL_vector_compute"). Since the processes of adding capability
and extension requirements are independent, there is no way to handle
cases like this. Therefore, for cases like this, enable adding the capability
requirement initially, then do the check for capability-inferred extensions.
---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 59 +++++++++++++------
 1 file changed, 41 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 7bcd36da0c21ee6..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -285,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
   type.getCapabilities(capabilities, storageClass);
 
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions)))
     return type;
 
   // Otherwise we need to adjust the type, which really means adjusting the
@@ -397,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
   cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
 
+  // If the bit-width related capabilities and extensions are not met
+  // for lower bit-width (<32-bit), convert it to 32-bit
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
+  if (!elementType)
+    return nullptr;
+  type = VectorType::get(type.getShape(), elementType);
+
+  SmallVector<spirv::Capability, 4> elidedCaps;
+  SmallVector<spirv::Extension, 4> elidedExts;
+
+  // Relax the bitwidth requirements for capabilities and extensions
+  if (options.emulateLT32BitScalarTypes) {
+    elidedCaps.push_back(spirv::Capability::Int8);
+    elidedCaps.push_back(spirv::Capability::Int16);
+    elidedCaps.push_back(spirv::Capability::Float16);
+  }
+  // For capabilities whose requirements were relaxed, also relax requirements
+  // for the extensions inferred from those capabilities (i.e., `elidedCaps`).
+  for (spirv::Capability cap : elidedCaps) {
+    std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+    if (ext)
+      llvm::append_range(elidedExts, *ext);
+  }
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
     return type;
 
-  auto elementType =
-      convertScalarType(targetEnv, options, scalarType, storageClass);
-  if (elementType)
-    return VectorType::get(type.getShape(), elementType);
   return nullptr;
 }
 
@@ -711,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
   SmallVector<ArrayRef<spirv::Capability>, 2> caps;
   scalarType.getExtensions(exts);
   scalarType.getCapabilities(caps);
-  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
-      failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+  if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+                                                     exts))) {
     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
     return castOp.getResult(0);
   }
@@ -1205,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
   for (Type valueType : valueTypes) {
-    typeExtensions.clear();
-    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
-    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
-                                          typeExtensions)))
-      return false;
-
     typeCapabilities.clear();
     cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
-    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
-                                           typeCapabilities)))
+    typeExtensions.clear();
+    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+    // Check capability and extension requirements along with
+    // capability-inferred extensions.
+    // If a capability is present, the extension that
+    // supports it should also be present; this reduces the burden of adding
+    // extension requirements that may or may not be added in
+    // CompositeType::getExtensions().
+    if (failed(checkCapabilityAndExtensionRequirements(
+            op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
       return false;
   }
 



More information about the Mlir-commits mailing list