[Mlir-commits] [mlir] e5c85a5 - [mlir][spirv] Support querying type extension/capability requirements
Lei Zhang
llvmlistbot at llvm.org
Thu Mar 12 16:40:00 PDT 2020
Author: Lei Zhang
Date: 2020-03-12T19:37:45-04:00
New Revision: e5c85a5a4ffaa0ff55f6d1d80a4e47f96ec0b9de
URL: https://github.com/llvm/llvm-project/commit/e5c85a5a4ffaa0ff55f6d1d80a4e47f96ec0b9de
DIFF: https://github.com/llvm/llvm-project/commit/e5c85a5a4ffaa0ff55f6d1d80a4e47f96ec0b9de.diff
LOG: [mlir][spirv] Support querying type extension/capability requirements
Previously we only considered the version/capability/extension requirements
on ops themselves. Some types in SPIR-V also require special extensions
or capabilities to be used. For example, non-32-bit integers/floats
will require different capabilities and/or extensions depending on
where they are used because it may mean special hardware abilities.
This commit adds query methods to SPIR-V type class hierarchy to support
querying extensions and capabilities. We don't go through ODS for
auto-generating such information given that we don't have them in
SPIR-V machine readable grammar and there are just a few types.
Differential Revision: https://reviews.llvm.org/D75875
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index faedfb3993cb..aad8b9f2ec7e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3000,7 +3000,7 @@ def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
-def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
+def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
def SPV_Composite :
AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
def SPV_Type : AnyTypeOf<[
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 73dd142aeed2..fa9e94046c46 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -71,16 +71,65 @@ enum Kind {
};
}
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
-class CompositeType : public Type {
+// Base SPIR-V type for providing availability queries.
+class SPIRVType : public Type {
public:
using Type::Type;
static bool classof(Type type);
+ /// The extension requirements for each type are following the
+ /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+ /// convention.
+ using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<spirv::Extension>>;
+
+ /// Appends to `extensions` the extensions needed for this type to appear in
+ /// the given `storage` class. This method does not guarantee the uniqueness
+ /// of extensions; the same extension may be appended multiple times.
+ void getExtensions(ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+
+ /// The capability requirements for each type are following the
+ /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+ /// convention.
+ using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<spirv::Capability>>;
+
+ /// Appends to `capabilities` the capabilities needed for this type to appear
+ /// in the given `storage` class. This method does not guarantee the
+ /// uniqueness of capabilities; the same capability may be appended multiple
+ /// times.
+ void getCapabilities(CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
+};
+
+// SPIR-V scalar type: bool type, integer type, floating point type.
+class ScalarType : public SPIRVType {
+public:
+ using SPIRVType::SPIRVType;
+
+ static bool classof(Type type);
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
+};
+
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+class CompositeType : public SPIRVType {
+public:
+ using SPIRVType::SPIRVType;
+
+ static bool classof(Type type);
+
unsigned getNumElements() const;
Type getElementType(unsigned) const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V array type
@@ -105,11 +154,16 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
bool hasLayout() const;
uint64_t getArrayStride() const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V image type
class ImageType
- : public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
+ : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
public:
using Base::Base;
@@ -141,11 +195,16 @@ class ImageType
ImageSamplerUseInfo getSamplerUseInfo() const;
ImageFormat getImageFormat() const;
// TODO(ravishankarm): Add support for Access qualifier
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V pointer type
-class PointerType
- : public Type::TypeBase<PointerType, Type, detail::PointerTypeStorage> {
+class PointerType : public Type::TypeBase<PointerType, SPIRVType,
+ detail::PointerTypeStorage> {
public:
using Base::Base;
@@ -156,11 +215,16 @@ class PointerType
Type getPointeeType() const;
StorageClass getStorageClass() const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V run-time array type
class RuntimeArrayType
- : public Type::TypeBase<RuntimeArrayType, Type,
+ : public Type::TypeBase<RuntimeArrayType, SPIRVType,
detail::RuntimeArrayTypeStorage> {
public:
using Base::Base;
@@ -170,6 +234,11 @@ class RuntimeArrayType
static RuntimeArrayType get(Type elementType);
Type getElementType() const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
};
// SPIR-V struct type
@@ -203,6 +272,31 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
Type getElementType(unsigned) const;
+  /// Range class for element types. Iterates over a contiguous `const Type *`
+  /// buffer and yields each element as a `Type`.
+  class ElementTypeRange
+      : public ::mlir::detail::indexed_accessor_range_base<
+            ElementTypeRange, const Type *, Type, Type, Type> {
+  private:
+    using RangeBaseT::RangeBaseT;
+
+    /// Returns `object` advanced by `index` elements.
+    /// See `mlir::detail::indexed_accessor_range_base` for details.
+    static const Type *offset_base(const Type *object, ptrdiff_t index) {
+      return object + index;
+    }
+    /// Returns the element stored `index` positions after `object`.
+    /// See `mlir::detail::indexed_accessor_range_base` for details.
+    static Type dereference_iterator(const Type *object, ptrdiff_t index) {
+      return object[index];
+    }
+
+    /// Allow StructType class access to constructors; it builds this range in
+    /// `StructType::getElementTypes()`.
+    friend StructType;
+
+    /// Allow base class access to `offset_base` and `dereference_iterator`.
+    friend RangeBaseT;
+  };
+
+ ElementTypeRange getElementTypes() const;
+
bool hasLayout() const;
uint64_t getOffset(unsigned) const;
@@ -216,6 +310,11 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
// Offset) associated with the `i`-th member of the StructType.
void getMemberDecorations(
unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<spirv::StorageClass> storage = llvm::None);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<spirv::StorageClass> storage = llvm::None);
};
} // end namespace spirv
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 468847dc5d0d..92dc5b82bb8a 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -57,6 +57,7 @@ ArrayRef<Extension> spirv::getImpliedExtensions(Version version) {
default:
return {};
case Version::V_1_3: {
+ // The following manual ArrayRef constructor call is to satisfy GCC 5.
static const Extension exts[] = {V_1_3_IMPLIED_EXTS};
return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
}
@@ -142,6 +143,17 @@ bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
+void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+}
+
+void ArrayType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+}
+
//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//
@@ -189,6 +201,50 @@ unsigned CompositeType::getNumElements() const {
}
}
+/// Appends to `extensions` the extensions needed for this composite type to
+/// appear in the given `storage` class. The query is forwarded to the concrete
+/// type selected by `getKind()`; for vectors it is forwarded to the scalar
+/// element type. Hitting an unlisted kind is a programming error.
+void CompositeType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    cast<ArrayType>().getExtensions(extensions, storage);
+    break;
+  case spirv::TypeKind::RuntimeArray:
+    // Runtime arrays are a separate type with their own query; casting them to
+    // ArrayType would be an invalid cast.
+    cast<RuntimeArrayType>().getExtensions(extensions, storage);
+    break;
+  case spirv::TypeKind::Struct:
+    cast<StructType>().getExtensions(extensions, storage);
+    break;
+  case StandardTypes::Vector:
+    cast<VectorType>().getElementType().cast<ScalarType>().getExtensions(
+        extensions, storage);
+    break;
+  default:
+    llvm_unreachable("invalid composite type");
+  }
+}
+
+/// Appends to `capabilities` the capabilities needed for this composite type
+/// to appear in the given `storage` class. The query is forwarded to the
+/// concrete type selected by `getKind()`; for vectors it is forwarded to the
+/// scalar element type. Hitting an unlisted kind is a programming error.
+void CompositeType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    cast<ArrayType>().getCapabilities(capabilities, storage);
+    break;
+  case spirv::TypeKind::RuntimeArray:
+    // Must dispatch to RuntimeArrayType: besides the element type's needs, its
+    // query also appends Capability::Shader, which ArrayType's does not.
+    cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
+    break;
+  case spirv::TypeKind::Struct:
+    cast<StructType>().getCapabilities(capabilities, storage);
+    break;
+  case StandardTypes::Vector:
+    cast<VectorType>().getElementType().cast<ScalarType>().getCapabilities(
+        capabilities, storage);
+    break;
+  default:
+    llvm_unreachable("invalid composite type");
+  }
+}
+
//===----------------------------------------------------------------------===//
// ImageType
//===----------------------------------------------------------------------===//
@@ -372,6 +428,20 @@ ImageFormat ImageType::getImageFormat() const {
return getImpl()->getImageFormat();
}
+void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
+ Optional<StorageClass>) {
+ // Image types do not require extra extensions thus far.
+}
+
+void ImageType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
+ if (auto dimCaps = spirv::getCapabilities(getDim()))
+ capabilities.push_back(*dimCaps);
+
+ if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
+ capabilities.push_back(*fmtCaps);
+}
+
//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
@@ -413,6 +483,35 @@ StorageClass PointerType::getStorageClass() const {
return getImpl()->getStorageClass();
}
+void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ if (storage)
+ assert(*storage == getStorageClass() && "inconsistent storage class!");
+
+ // Use this pointer type's storage class because this pointer indicates we are
+ // using the pointee type in that specific storage class.
+ getPointeeType().cast<SPIRVType>().getExtensions(extensions,
+ getStorageClass());
+
+ if (auto scExts = spirv::getExtensions(getStorageClass()))
+ extensions.push_back(*scExts);
+}
+
+void PointerType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ if (storage)
+ assert(*storage == getStorageClass() && "inconsistent storage class!");
+
+ // Use this pointer type's storage class because this pointer indicates we are
+ // using the pointee type in that specific storage class.
+ getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
+ getStorageClass());
+
+ if (auto scCaps = spirv::getCapabilities(getStorageClass()))
+ capabilities.push_back(*scCaps);
+}
+
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
@@ -440,6 +539,181 @@ RuntimeArrayType RuntimeArrayType::get(Type elementType) {
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
+void RuntimeArrayType::getExtensions(
+ SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
+}
+
+void RuntimeArrayType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ {
+ static const Capability caps[] = {Capability::Shader};
+ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
+ capabilities.push_back(ref);
+ }
+ getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
+}
+
+//===----------------------------------------------------------------------===//
+// ScalarType
+//===----------------------------------------------------------------------===//
+
+bool ScalarType::classof(Type type) { return type.isIntOrFloat(); }
+
+void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ // 8- or 16-bit integer/floating-point numbers will require extra extensions
+ // to appear in interface storage classes. See SPV_KHR_16bit_storage and
+ // SPV_KHR_8bit_storage for more details.
+ if (!storage)
+ return;
+
+ switch (*storage) {
+ case StorageClass::PushConstant:
+ case StorageClass::StorageBuffer:
+ case StorageClass::Uniform:
+ if (getIntOrFloatBitWidth() == 8) {
+ static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
+ ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
+ extensions.push_back(ref);
+ }
+ LLVM_FALLTHROUGH;
+ case StorageClass::Input:
+ case StorageClass::Output:
+ if (getIntOrFloatBitWidth() == 16) {
+ static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
+ ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
+ extensions.push_back(ref);
+ }
+ break;
+ default:
+ break;
+ }
+}
+
+void ScalarType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ unsigned bitwidth = getIntOrFloatBitWidth();
+
+ // 8- or 16-bit integer/floating-point numbers will require extra capabilities
+ // to appear in interface storage classes. See SPV_KHR_16bit_storage and
+ // SPV_KHR_8bit_storage for more details.
+
+#define STORAGE_CASE(storage, cap8, cap16) \
+ case StorageClass::storage: { \
+ if (bitwidth == 8) { \
+ static const Capability caps[] = {Capability::cap8}; \
+ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
+ capabilities.push_back(ref); \
+ } else if (bitwidth == 16) { \
+ static const Capability caps[] = {Capability::cap16}; \
+ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
+ capabilities.push_back(ref); \
+ } \
+ } break
+
+ if (storage) {
+ switch (*storage) {
+ STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
+ STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
+ StorageBuffer16BitAccess);
+ STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
+ StorageUniform16);
+ case StorageClass::Input:
+ case StorageClass::Output:
+ if (bitwidth == 16) {
+ static const Capability caps[] = {Capability::StorageInputOutput16};
+ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
+ capabilities.push_back(ref);
+ }
+ break;
+ default:
+ break;
+ }
+ return;
+ }
+#undef STORAGE_CASE
+
+ // For other non-interface storage classes, require a different set of
+ // capabilities for special bitwidths.
+
+#define WIDTH_CASE(type, width) \
+ case width: { \
+ static const Capability caps[] = {Capability::type##width}; \
+ ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
+ capabilities.push_back(ref); \
+ } break
+
+ if (auto intType = dyn_cast<IntegerType>()) {
+ switch (bitwidth) {
+ case 32:
+ case 1:
+ break;
+ WIDTH_CASE(Int, 8);
+ WIDTH_CASE(Int, 16);
+ WIDTH_CASE(Int, 64);
+ default:
+ llvm_unreachable("invalid bitwidth to getCapabilities");
+ }
+ } else {
+ assert(isa<FloatType>());
+ switch (bitwidth) {
+ case 32:
+ break;
+ WIDTH_CASE(Float, 16);
+ WIDTH_CASE(Float, 64);
+ default:
+ llvm_unreachable("invalid bitwidth to getCapabilities");
+ }
+ }
+
+#undef WIDTH_CASE
+}
+
+//===----------------------------------------------------------------------===//
+// SPIRVType
+//===----------------------------------------------------------------------===//
+
+bool SPIRVType::classof(Type type) {
+ return type.isa<ScalarType>() || type.isa<VectorType>() ||
+ (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+ type.getKind() <= TypeKind::LAST_SPIRV_TYPE);
+}
+
+void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ if (auto scalarType = dyn_cast<ScalarType>()) {
+ scalarType.getExtensions(extensions, storage);
+ } else if (auto compositeType = dyn_cast<CompositeType>()) {
+ compositeType.getExtensions(extensions, storage);
+ } else if (auto ptrType = dyn_cast<PointerType>()) {
+ ptrType.getExtensions(extensions, storage);
+ } else if (auto imageType = dyn_cast<ImageType>()) {
+ imageType.getExtensions(extensions, storage);
+ } else {
+ llvm_unreachable("invalid SPIR-V Type to getExtensions");
+ }
+}
+
+void SPIRVType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ if (auto scalarType = dyn_cast<ScalarType>()) {
+ scalarType.getCapabilities(capabilities, storage);
+ } else if (auto compositeType = dyn_cast<CompositeType>()) {
+ compositeType.getCapabilities(capabilities, storage);
+ } else if (auto ptrType = dyn_cast<PointerType>()) {
+ ptrType.getCapabilities(capabilities, storage);
+ } else if (auto imageType = dyn_cast<ImageType>()) {
+ imageType.getCapabilities(capabilities, storage);
+ } else {
+ llvm_unreachable("invalid SPIR-V Type to getCapabilities");
+ }
+}
+
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
@@ -540,18 +814,18 @@ unsigned StructType::getNumElements() const {
}
Type StructType::getElementType(unsigned index) const {
- assert(
- getNumElements() > index &&
- "element index is more than number of members of the SPIR-V StructType");
+ assert(getNumElements() > index && "member index out of range");
return getImpl()->memberTypes[index];
}
+StructType::ElementTypeRange StructType::getElementTypes() const {
+ return ElementTypeRange(getImpl()->memberTypes, getNumElements());
+}
+
bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
uint64_t StructType::getOffset(unsigned index) const {
- assert(
- getNumElements() > index &&
- "element index is more than number of members of the SPIR-V StructType");
+ assert(getNumElements() > index && "member index out of range");
return getImpl()->layoutInfo[index];
}
@@ -579,3 +853,16 @@ void StructType::getMemberDecorations(
}
}
}
+
+void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ Optional<StorageClass> storage) {
+ for (Type elementType : getElementTypes())
+ elementType.cast<SPIRVType>().getExtensions(extensions, storage);
+}
+
+void StructType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ Optional<StorageClass> storage) {
+ for (Type elementType : getElementTypes())
+ elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
+}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6647431b70fc..fff15c185749 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -33,6 +33,71 @@ class UpdateVCEPass final
};
} // namespace
+/// Checks that `candidates` extension requirements are possible to be satisfied
+/// with the given `allowedExtensions` and updates `deducedExtensions` if so.
+/// Emits errors attaching to the given `op` on failures.
+///
+/// `candidates` is a vector of vector for extension requirements following
+/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+/// convention.
+static LogicalResult checkAndUpdateExtensionRequirements(
+ Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+ const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+ llvm::SetVector<spirv::Extension> &deducedExtensions) {
+ for (const auto &ors : candidates) {
+ auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
+ return allowedExtensions.count(ext);
+ });
+
+ if (chosen != ors.end()) {
+ deducedExtensions.insert(*chosen);
+ } else {
+ SmallVector<StringRef, 4> extStrings;
+ for (spirv::Extension ext : ors)
+ extStrings.push_back(spirv::stringifyExtension(ext));
+
+ return op->emitError("'")
+ << op->getName() << "' requires at least one extension in ["
+ << llvm::join(extStrings, ", ")
+ << "] but none allowed in target environment";
+ }
+ }
+ return success();
+}
+
+/// Checks that `candidates` capability requirements are possible to be
+/// satisfied with the given `allowedCapabilities` and updates
+/// `deducedCapabilities` if so. Emits errors attaching to the given `op` on
+/// failures.
+///
+/// `candidates` is a vector of vector for capability requirements following
+/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+/// convention.
+static LogicalResult checkAndUpdateCapabilityRequirements(
+ Operation *op,
+ const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+ const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+ llvm::SetVector<spirv::Capability> &deducedCapabilities) {
+ for (const auto &ors : candidates) {
+ auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
+ return allowedCapabilities.count(cap);
+ });
+
+ if (chosen != ors.end()) {
+ deducedCapabilities.insert(*chosen);
+ } else {
+ SmallVector<StringRef, 4> capStrings;
+ for (spirv::Capability cap : ors)
+ capStrings.push_back(spirv::stringifyCapability(cap));
+
+ return op->emitError("'")
+ << op->getName() << "' requires at least one capability in ["
+ << llvm::join(capStrings, ", ")
+ << "] but none allowed in target environment";
+ }
+ }
+ return success();
+}
+
void UpdateVCEPass::runOnOperation() {
spirv::ModuleOp module = getOperation();
@@ -70,6 +135,7 @@ void UpdateVCEPass::runOnOperation() {
// Walk each SPIR-V op to deduce the minimal version/extension/capability
// requirements.
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
+ // Op min version requirements
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
if (deducedVersion > allowedVersion) {
@@ -80,62 +146,44 @@ void UpdateVCEPass::runOnOperation() {
}
}
- // Deduce this op's extension requirement. For each op, the query interfacce
- // returns a vector of vector for its extension requirements following
- // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
- // convention. Ops not implementing QueryExtensionInterface do not require
- // extensions to be available.
- if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
- for (const auto &ors : extensions.getExtensions()) {
- bool satisfied = false; // True when at least one extension can be used
- for (spirv::Extension ext : ors) {
- if (allowedExtensions.count(ext)) {
- deducedExtensions.insert(ext);
- satisfied = true;
- break;
- }
- }
-
- if (!satisfied) {
- SmallVector<StringRef, 4> extStrings;
- for (spirv::Extension ext : ors)
- extStrings.push_back(spirv::stringifyExtension(ext));
-
- return op->emitError("'")
- << op->getName() << "' requires at least one extension in ["
- << llvm::join(extStrings, ", ")
- << "] but none allowed in target environment";
- }
- }
- }
-
- // Deduce this op's capability requirement. For each op, the queryinterface
- // returns a vector of vector for its capability requirements following
- // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
- // convention. Ops not implementing QueryExtensionInterface do not require
- // extensions to be available.
- if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
- for (const auto &ors : capabilities.getCapabilities()) {
- bool satisfied = false; // True when at least one capability can be used
- for (spirv::Capability cap : ors) {
- if (allowedCapabilities.count(cap)) {
- deducedCapabilities.insert(cap);
- satisfied = true;
- break;
- }
- }
-
- if (!satisfied) {
- SmallVector<StringRef, 4> capStrings;
- for (spirv::Capability cap : ors)
- capStrings.push_back(spirv::stringifyCapability(cap));
-
- return op->emitError("'")
- << op->getName() << "' requires at least one capability in ["
- << llvm::join(capStrings, ", ")
- << "] but none allowed in target environment";
- }
- }
+ // Op extension requirements
+ if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
+ if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions,
+ extensions.getExtensions(),
+ deducedExtensions)))
+ return WalkResult::interrupt();
+
+ // Op capability requirements
+ if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
+ if (failed(checkAndUpdateCapabilityRequirements(
+ op, allowedCapabilities, capabilities.getCapabilities(),
+ deducedCapabilities)))
+ return WalkResult::interrupt();
+
+ SmallVector<Type, 4> valueTypes;
+ valueTypes.append(op->operand_type_begin(), op->operand_type_end());
+ valueTypes.append(op->result_type_begin(), op->result_type_end());
+
+ // Special treatment for global variables, whose type requirements are
+ // conveyed by type attributes.
+ if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
+ valueTypes.push_back(globalVar.type());
+
+ // Requirements from values' types
+ SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
+ SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
+ for (Type valueType : valueTypes) {
+ typeExtensions.clear();
+ valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+ if (failed(checkAndUpdateExtensionRequirements(
+ op, allowedExtensions, typeExtensions, deducedExtensions)))
+ return WalkResult::interrupt();
+
+ typeCapabilities.clear();
+ valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+ if (failed(checkAndUpdateCapabilityRequirements(
+ op, allowedCapabilities, typeCapabilities, deducedCapabilities)))
+ return WalkResult::interrupt();
}
return WalkResult::advance();
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 60bf13e2571e..572db88e5f9e 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -107,6 +107,36 @@ spv.module Logical GLSL450 attributes {
}
}
+// Test type required capabilities
+
+// Using 8-bit integers in non-interface storage class requires Int8.
+// CHECK: requires #spv.vce<v1.0, [Int8, Shader], []>
+spv.module Logical GLSL450 attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.3, [Shader, Int8], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+ spv.func @iadd_function(%val : i8) -> i8 "None" {
+ %0 = spv.IAdd %val, %val : i8
+ spv.ReturnValue %0: i8
+ }
+}
+
+// Using 16-bit floats in non-interface storage class requires Float16.
+// CHECK: requires #spv.vce<v1.0, [Float16, Shader], []>
+spv.module Logical GLSL450 attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.3, [Shader, Float16], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+ spv.func @fadd_function(%val : f16) -> f16 "None" {
+ %0 = spv.FAdd %val, %val : f16
+ spv.ReturnValue %0: f16
+ }
+}
+
//===----------------------------------------------------------------------===//
// Extension
//===----------------------------------------------------------------------===//
@@ -144,3 +174,35 @@ spv.module Logical Vulkan attributes {
spv.ReturnValue %0: i32
}
}
+
+// Test type required extensions
+
+// Using 16-bit integers in interface storage class requires additional
+// extensions and capabilities.
+// CHECK: requires #spv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+spv.module Logical GLSL450 attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+ spv.func @iadd_storage_buffer(%ptr : !spv.ptr<i16, StorageBuffer>) -> i16 "None" {
+ %0 = spv.Load "StorageBuffer" %ptr : i16
+ %1 = spv.IAdd %0, %0 : i16
+ spv.ReturnValue %1: i16
+ }
+}
+
+// Complicated nested types
+// * Buffer requires ImageBuffer or SampledBuffer.
+// * Rg32f requires StorageImageExtendedFormats.
+// CHECK: requires #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+spv.module Logical GLSL450 attributes {
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, ImageBuffer, StorageImageExtendedFormats], []>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+ spv.globalVariable @data : !spv.ptr<!spv.struct<i8 [0], f16 [2], i64 [4]>, Uniform>
+ spv.globalVariable @img : !spv.ptr<!spv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Rg32f>, UniformConstant>
+}
More information about the Mlir-commits
mailing list