[Mlir-commits] [mlir] Use combined-check for type related extension and capability requirements (PR #68033)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 2 13:37:27 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
Please review the second commit. The first commit is included only for completeness; it is being reviewed separately in https://github.com/llvm/llvm-project/pull/68031.
---
Full diff: https://github.com/llvm/llvm-project/pull/68033.diff
1 File Affected:
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+102-24)
``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
- const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+ const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+ const ArrayRef<spirv::Extension> elidedCandidates = {}) {
for (const auto &ors : candidates) {
- if (targetEnv.allows(ors))
+ if (targetEnv.allows(ors) ||
+ llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+ return llvm::is_contained(ors, elidedExt);
+ }))
continue;
LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
- const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+ const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+ const ArrayRef<spirv::Capability> elidedCandidates = {}) {
for (const auto &ors : candidates) {
- if (targetEnv.allows(ors))
+ if (targetEnv.allows(ors) ||
+ llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+ return llvm::is_contained(ors, elidedCap);
+ }))
continue;
LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Checks that the capability requirements (`capCandidates`) and the extension
+/// requirements (`extCandidates`) can both be satisfied by `targetEnv`.
+/// Each candidate is an OR-set: at least one member must be allowed by
+/// `targetEnv`, unless the OR-set contains an elided entry.
+/// In addition, for every capability declared in `targetEnv` (not only those
+/// in `capCandidates`), the extensions that enable it are also required.
+/// `elidedCapCandidates`/`elidedExtCandidates` let callers relax requirements
+/// based on an option (e.g., the bitwidth relaxation in `convertVectorType()`);
+/// elided capabilities contribute no inferred extension requirements.
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+    const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+    const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+    const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+  // Check capabilities first so the extension list is not built for a type
+  // that is rejected anyway.
+  if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+                                         elidedCapCandidates)))
+    return failure();
+
+  SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+  llvm::append_range(updatedExtCandidates, extCandidates);
+
+  // Add capability-inferred extensions to the extension requirement list. Only
+  // capabilities already declared in `targetEnv` are considered.
+  //
+  // WARNING: Some capabilities are part of both the core SPIR-V specification
+  // and an extension (e.g., 'Groups' is part of both the core specification
+  // and SPV_AMD_shader_ballot), so no extension is inferred for them.
+  static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+  for (spirv::Capability cap : targetEnv.getAttr().getCapabilities()) {
+    if (llvm::is_contained(multiModalCaps, cap) ||
+        llvm::is_contained(elidedCapCandidates, cap))
+      continue;
+    if (std::optional<ArrayRef<spirv::Extension>> exts =
+            spirv::getExtensions(cap))
+      updatedExtCandidates.push_back(*exts);
+  }
+
+  // The capability check already passed above, so the overall result is
+  // decided by the extension check.
+  return checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+                                    elidedExtCandidates);
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
@@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
type.getCapabilities(capabilities, storageClass);
// If all requirements are met, then we can accept this type as-is.
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
- succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+ if (succeeded(checkCapabilityAndExtensionRequirements(
+ type, targetEnv, capabilities, extensions)))
return type;
// Otherwise we need to adjust the type, which really means adjusting the
@@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
+ // If the bit-width related capabilities and extensions are not met
+ // for lower bit-width (<32-bit), convert it to 32-bit
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
+ if (!elementType)
+ return nullptr;
+ type = VectorType::get(type.getShape(), elementType);
+
+ SmallVector<spirv::Capability, 4> elidedCaps;
+ SmallVector<spirv::Extension, 4> elidedExts;
+
+ // Relax the bitwidth requirements for capabilities and extensions
+ if (options.emulateLT32BitScalarTypes) {
+ elidedCaps.push_back(spirv::Capability::Int8);
+ elidedCaps.push_back(spirv::Capability::Int16);
+ elidedCaps.push_back(spirv::Capability::Float16);
+ }
+ // For capabilities whose requirements were relaxed, relax requirements for
+ // the extensions that were infered by those capabilities (e.g., elidedCaps)
+ for (spirv::Capability cap : elidedCaps) {
+ std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+ if (ext)
+ llvm::append_range(elidedExts, *ext);
+ }
// If all requirements are met, then we can accept this type as-is.
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
- succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+ if (succeeded(checkCapabilityAndExtensionRequirements(
+ type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
return type;
- auto elementType =
- convertScalarType(targetEnv, options, scalarType, storageClass);
- if (elementType)
- return VectorType::get(type.getShape(), elementType);
return nullptr;
}
@@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
SmallVector<ArrayRef<spirv::Capability>, 2> caps;
scalarType.getExtensions(exts);
scalarType.getCapabilities(caps);
- if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
- failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+ if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+ exts))) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return castOp.getResult(0);
}
@@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
- typeExtensions.clear();
- cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
- if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
- typeExtensions)))
- return false;
-
typeCapabilities.clear();
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
- if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
- typeCapabilities)))
+ typeExtensions.clear();
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+ // Checking for capability and extension requirements along with capability
+ // infered extensions.
+ // If a capability is present, the extension that
+ // supports it should also be present, this reduces the burden of adding
+ // extension requirement that may or maynot be added in
+ // CompositeType::getExtensions().
+ if (failed(checkCapabilityAndExtensionRequirements(
+ op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
return false;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/68033
More information about the Mlir-commits mailing list