[Mlir-commits] [mlir] e5c85a5 - [mlir][spirv] Support querying type extension/capability requirements

Lei Zhang llvmlistbot at llvm.org
Thu Mar 12 16:40:00 PDT 2020


Author: Lei Zhang
Date: 2020-03-12T19:37:45-04:00
New Revision: e5c85a5a4ffaa0ff55f6d1d80a4e47f96ec0b9de

URL: https://github.com/llvm/llvm-project/commit/e5c85a5a4ffaa0ff55f6d1d80a4e47f96ec0b9de
DIFF: https://github.com/llvm/llvm-project/commit/e5c85a5a4ffaa0ff55f6d1d80a4e47f96ec0b9de.diff

LOG: [mlir][spirv] Support querying type extension/capability requirements

Previously we only considered the version/capability/extension requirements
of ops themselves. Some SPIR-V types also require special extensions
or capabilities to be used. For example, non-32-bit integers/floats
require different capabilities and/or extensions depending on
where they are used, because they may need special hardware support.

This commit adds query methods to the SPIR-V type class hierarchy to support
querying extensions and capabilities. We don't use ODS to auto-generate
this information because it is not available in the SPIR-V machine-readable
grammar and there are only a few such types.

Differential Revision: https://reviews.llvm.org/D75875

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
    mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
    mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index faedfb3993cb..aad8b9f2ec7e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3000,7 +3000,7 @@ def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
 
 def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
 def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
-def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
+// Aggregate types: arrays, runtime arrays, and structs.
+def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Composite :
     AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Type : AnyTypeOf<[

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
index 73dd142aeed2..fa9e94046c46 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -71,16 +71,65 @@ enum Kind {
 };
 }
 
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
-class CompositeType : public Type {
+// Base SPIR-V type for providing availability queries.
+//
+// Subclasses declare same-named (non-virtual) getExtensions/getCapabilities
+// methods; the ones declared here dispatch to the concrete subclass by
+// dyn_cast.
+class SPIRVType : public Type {
 public:
   using Type::Type;
 
   static bool classof(Type type);
 
+  /// The extension requirements for each type follow the
+  /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+  /// convention.
+  using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<spirv::Extension>>;
+
+  /// Appends to `extensions` the extensions needed for this type to appear in
+  /// the given `storage` class. This method does not guarantee the uniqueness
+  /// of extensions; the same extension may be appended multiple times.
+  void getExtensions(ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+
+  /// The capability requirements for each type follow the
+  /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+  /// convention.
+  using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<spirv::Capability>>;
+
+  /// Appends to `capabilities` the capabilities needed for this type to appear
+  /// in the given `storage` class. This method does not guarantee the
+  /// uniqueness of capabilities; the same capability may be appended multiple
+  /// times.
+  void getCapabilities(CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
+};
+
+// SPIR-V scalar type: bool type, integer type, floating point type.
+class ScalarType : public SPIRVType {
+public:
+  using SPIRVType::SPIRVType;
+
+  static bool classof(Type type);
+
+  /// Scalar-specific requirements; these depend on the bitwidth and, for
+  /// interface storage classes, on `storage`.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
+};
+
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+class CompositeType : public SPIRVType {
+public:
+  using SPIRVType::SPIRVType;
+
+  static bool classof(Type type);
+
   unsigned getNumElements() const;
 
   Type getElementType(unsigned) const;
+
+  /// Dispatches to the concrete composite type (array, runtime array, struct,
+  /// or vector element type).
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
 };
 
 // SPIR-V array type
@@ -105,11 +154,16 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
   bool hasLayout() const;
 
   uint64_t getArrayStride() const;
+
+  /// Appends the extensions/capabilities required by this array's element
+  /// type for use in `storage`; the array itself adds none of its own.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
 };
 
 // SPIR-V image type
 class ImageType
-    : public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
+    : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> {
 public:
   using Base::Base;
 
@@ -141,11 +195,16 @@ class ImageType
   ImageSamplerUseInfo getSamplerUseInfo() const;
   ImageFormat getImageFormat() const;
   // TODO(ravishankarm): Add support for Access qualifier
+
+  /// No extensions are currently required for image types (nothing is
+  /// appended). Capabilities are derived from the image's dimensionality and
+  /// image format; `storage` is ignored by both queries.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
 };
 
 // SPIR-V pointer type
-class PointerType
-    : public Type::TypeBase<PointerType, Type, detail::PointerTypeStorage> {
+class PointerType : public Type::TypeBase<PointerType, SPIRVType,
+                                          detail::PointerTypeStorage> {
 public:
   using Base::Base;
 
@@ -156,11 +215,16 @@ class PointerType
   Type getPointeeType() const;
 
   StorageClass getStorageClass() const;
+
+  /// Queries the pointee type using this pointer's own storage class, then
+  /// appends the storage class's requirements. A non-empty `storage` must
+  /// equal this pointer's storage class (checked by assertion).
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
 };
 
 // SPIR-V run-time array type
 class RuntimeArrayType
-    : public Type::TypeBase<RuntimeArrayType, Type,
+    : public Type::TypeBase<RuntimeArrayType, SPIRVType,
                             detail::RuntimeArrayTypeStorage> {
 public:
   using Base::Base;
@@ -170,6 +234,11 @@ class RuntimeArrayType
   static RuntimeArrayType get(Type elementType);
 
   Type getElementType() const;
+
+  /// Extensions come only from the element type. Capabilities always include
+  /// Shader, followed by the element type's capabilities.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
 };
 
 // SPIR-V struct type
@@ -203,6 +272,31 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
 
   Type getElementType(unsigned) const;
 
+  /// Range class for element types.
+  ///
+  /// Walks a contiguous array of member types by index and yields each
+  /// element by value.
+  class ElementTypeRange
+      : public ::mlir::detail::indexed_accessor_range_base<
+            ElementTypeRange, const Type *, Type, Type, Type> {
+  private:
+    // The access specifier of a using-declaration is ignored for inherited
+    // constructors: they keep the access they have in the base class.
+    using RangeBaseT::RangeBaseT;
+
+    /// See `mlir::detail::indexed_accessor_range_base` for details.
+    static const Type *offset_base(const Type *object, ptrdiff_t index) {
+      return object + index;
+    }
+    /// See `mlir::detail::indexed_accessor_range_base` for details.
+    static Type dereference_iterator(const Type *object, ptrdiff_t index) {
+      return object[index];
+    }
+
+    /// Allow base class access to `offset_base` and `dereference_iterator`.
+    friend RangeBaseT;
+  };
+
+  /// Returns a range over all member types, in member-index order.
+  ElementTypeRange getElementTypes() const;
+
   bool hasLayout() const;
 
   uint64_t getOffset(unsigned) const;
@@ -216,6 +310,11 @@ class StructType : public Type::TypeBase<StructType, CompositeType,
   // Offset) associated with the `i`-th member of the StructType.
   void getMemberDecorations(
       unsigned i, SmallVectorImpl<spirv::Decoration> &memberDecorations) const;
+
+  /// Appends the requirements of each member type, in member order, queried
+  /// with the same `storage` class as the struct.
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     Optional<spirv::StorageClass> storage = llvm::None);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       Optional<spirv::StorageClass> storage = llvm::None);
 };
 
 } // end namespace spirv

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
index 468847dc5d0d..92dc5b82bb8a 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -57,6 +57,7 @@ ArrayRef<Extension> spirv::getImpliedExtensions(Version version) {
   default:
     return {};
   case Version::V_1_3: {
+    // The following manual ArrayRef constructor call is to satisfy GCC 5.
     static const Extension exts[] = {V_1_3_IMPLIED_EXTS};
     return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
   }
@@ -142,6 +143,17 @@ bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; }
 
 uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; }
 
+/// An array contributes no requirements of its own: both queries forward to
+/// the element type, using the same `storage` class.
+void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                              Optional<StorageClass> storage) {
+  auto elementType = getElementType().cast<SPIRVType>();
+  elementType.getExtensions(extensions, storage);
+}
+
+void ArrayType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  auto elementType = getElementType().cast<SPIRVType>();
+  elementType.getCapabilities(capabilities, storage);
+}
+
 //===----------------------------------------------------------------------===//
 // CompositeType
 //===----------------------------------------------------------------------===//
@@ -189,6 +201,50 @@ unsigned CompositeType::getNumElements() const {
   }
 }
 
+/// Dispatches the extension query to the concrete composite type. Vectors
+/// forward to their scalar element type.
+void CompositeType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    cast<ArrayType>().getExtensions(extensions, storage);
+    break;
+  case spirv::TypeKind::RuntimeArray:
+    // Must cast to RuntimeArrayType: casting a runtime array to ArrayType
+    // fails the cast's kind check.
+    cast<RuntimeArrayType>().getExtensions(extensions, storage);
+    break;
+  case spirv::TypeKind::Struct:
+    cast<StructType>().getExtensions(extensions, storage);
+    break;
+  case StandardTypes::Vector:
+    cast<VectorType>().getElementType().cast<ScalarType>().getExtensions(
+        extensions, storage);
+    break;
+  default:
+    llvm_unreachable("invalid composite type");
+  }
+}
+
+/// Dispatches the capability query to the concrete composite type. Vectors
+/// forward to their scalar element type.
+void CompositeType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  switch (getKind()) {
+  case spirv::TypeKind::Array:
+    cast<ArrayType>().getCapabilities(capabilities, storage);
+    break;
+  case spirv::TypeKind::RuntimeArray:
+    // Must cast to RuntimeArrayType (see getExtensions above).
+    cast<RuntimeArrayType>().getCapabilities(capabilities, storage);
+    break;
+  case spirv::TypeKind::Struct:
+    cast<StructType>().getCapabilities(capabilities, storage);
+    break;
+  case StandardTypes::Vector:
+    cast<VectorType>().getElementType().cast<ScalarType>().getCapabilities(
+        capabilities, storage);
+    break;
+  default:
+    llvm_unreachable("invalid composite type");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // ImageType
 //===----------------------------------------------------------------------===//
@@ -372,6 +428,20 @@ ImageFormat ImageType::getImageFormat() const {
   return getImpl()->getImageFormat();
 }
 
+/// No image-specific extension requirement is modeled yet, so this appends
+/// nothing.
+void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
+                              Optional<StorageClass>) {}
+
+/// Appends the capabilities implied by this image's dimensionality and then
+/// by its image format. The storage class does not affect the result.
+void ImageType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
+  auto dimRequirement = spirv::getCapabilities(getDim());
+  if (dimRequirement)
+    capabilities.push_back(*dimRequirement);
+
+  auto formatRequirement = spirv::getCapabilities(getImageFormat());
+  if (formatRequirement)
+    capabilities.push_back(*formatRequirement);
+}
+
 //===----------------------------------------------------------------------===//
 // PointerType
 //===----------------------------------------------------------------------===//
@@ -413,6 +483,35 @@ StorageClass PointerType::getStorageClass() const {
   return getImpl()->getStorageClass();
 }
 
+/// Appends the pointee's extensions (queried in this pointer's storage class)
+/// followed by any extensions the storage class itself needs. A caller-given
+/// `storage` must agree with this pointer's storage class.
+void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                                Optional<StorageClass> storage) {
+  StorageClass pointerStorage = getStorageClass();
+  (void)storage; // Only inspected by the assertion below.
+  assert((!storage || *storage == pointerStorage) &&
+         "inconsistent storage class!");
+
+  // The pointee lives in the storage class this pointer refers to, so that is
+  // the storage class it must be queried with.
+  auto pointee = getPointeeType().cast<SPIRVType>();
+  pointee.getExtensions(extensions, pointerStorage);
+
+  auto storageRequirement = spirv::getExtensions(pointerStorage);
+  if (storageRequirement)
+    extensions.push_back(*storageRequirement);
+}
+
+/// Appends the pointee's capabilities (queried in this pointer's storage
+/// class) followed by any capabilities the storage class itself needs. A
+/// caller-given `storage` must agree with this pointer's storage class.
+void PointerType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  StorageClass pointerStorage = getStorageClass();
+  (void)storage; // Only inspected by the assertion below.
+  assert((!storage || *storage == pointerStorage) &&
+         "inconsistent storage class!");
+
+  // The pointee lives in the storage class this pointer refers to, so that is
+  // the storage class it must be queried with.
+  auto pointee = getPointeeType().cast<SPIRVType>();
+  pointee.getCapabilities(capabilities, pointerStorage);
+
+  auto storageRequirement = spirv::getCapabilities(pointerStorage);
+  if (storageRequirement)
+    capabilities.push_back(*storageRequirement);
+}
+
 //===----------------------------------------------------------------------===//
 // RuntimeArrayType
 //===----------------------------------------------------------------------===//
@@ -440,6 +539,181 @@ RuntimeArrayType RuntimeArrayType::get(Type elementType) {
 
 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
 
+/// A runtime array needs no extensions beyond those of its element type.
+void RuntimeArrayType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    Optional<StorageClass> storage) {
+  auto elementType = getElementType().cast<SPIRVType>();
+  elementType.getExtensions(extensions, storage);
+}
+
+/// A runtime array always requires Shader; the element type's capabilities
+/// are appended after it.
+void RuntimeArrayType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  // Explicit ArrayRef construction (rather than implicit conversion from the
+  // array) mirrors the GCC 5 workaround used elsewhere in this file.
+  static const Capability shaderCap[] = {Capability::Shader};
+  capabilities.push_back(
+      ArrayRef<Capability>(shaderCap, llvm::array_lengthof(shaderCap)));
+
+  auto elementType = getElementType().cast<SPIRVType>();
+  elementType.getCapabilities(capabilities, storage);
+}
+
+//===----------------------------------------------------------------------===//
+// ScalarType
+//===----------------------------------------------------------------------===//
+
+bool ScalarType::classof(Type type) { return type.isIntOrFloat(); }
+
+/// Appends the extensions needed for this scalar to appear in `storage`.
+/// Only 8-/16-bit scalars in interface storage classes need extensions.
+void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                               Optional<StorageClass> storage) {
+  // 8- or 16-bit integer/floating-point numbers will require extra extensions
+  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
+  // SPV_KHR_8bit_storage for more details.
+  if (!storage)
+    return;
+
+  switch (*storage) {
+  case StorageClass::PushConstant:
+  case StorageClass::StorageBuffer:
+  case StorageClass::Uniform:
+    if (getIntOrFloatBitWidth() == 8) {
+      static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
+      ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
+      extensions.push_back(ref);
+    }
+    // 16-bit scalars in these storage classes share the Input/Output handling.
+    LLVM_FALLTHROUGH;
+  case StorageClass::Input:
+  case StorageClass::Output:
+    if (getIntOrFloatBitWidth() == 16) {
+      static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
+      ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
+      extensions.push_back(ref);
+    }
+    break;
+  default:
+    break;
+  }
+}
+
+/// Appends the capabilities needed for this scalar to appear in `storage`.
+///
+/// For the interface storage classes handled below (PushConstant,
+/// StorageBuffer, Uniform, Input, Output) only the storage-specific 8-/16-bit
+/// capabilities are appended. For any other storage class, or when no storage
+/// class is given, the generic bitwidth capabilities (Int8/Int16/Int64,
+/// Float16/Float64) are appended instead.
+void ScalarType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  unsigned bitwidth = getIntOrFloatBitWidth();
+
+  // 8- or 16-bit integer/floating-point numbers will require extra capabilities
+  // to appear in interface storage classes. See SPV_KHR_16bit_storage and
+  // SPV_KHR_8bit_storage for more details. Every interface case returns, so
+  // its requirements replace the generic bitwidth capabilities below.
+  //
+  // NOTE(review): 64-bit scalars in interface storage classes currently add no
+  // capability at all; confirm whether Int64/Float64 should be required there.
+
+#define STORAGE_CASE(storage, cap8, cap16)                                     \
+  case StorageClass::storage: {                                                \
+    if (bitwidth == 8) {                                                       \
+      static const Capability caps[] = {Capability::cap8};                     \
+      ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
+      capabilities.push_back(ref);                                             \
+    } else if (bitwidth == 16) {                                               \
+      static const Capability caps[] = {Capability::cap16};                    \
+      ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));              \
+      capabilities.push_back(ref);                                             \
+    }                                                                          \
+  }                                                                            \
+    return
+
+  if (storage) {
+    switch (*storage) {
+      STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
+      STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
+                   StorageBuffer16BitAccess);
+      STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
+                   StorageUniform16);
+    case StorageClass::Input:
+    case StorageClass::Output:
+      if (bitwidth == 16) {
+        static const Capability caps[] = {Capability::StorageInputOutput16};
+        ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
+        capabilities.push_back(ref);
+      }
+      return;
+    default:
+      // Not an interface storage class: continue to the generic bitwidth
+      // capabilities below. Previously this path returned with no
+      // requirement, so e.g. i8 in Function storage did not require Int8.
+      break;
+    }
+  }
+#undef STORAGE_CASE
+
+  // For other non-interface storage classes, require a different set of
+  // capabilities for special bitwidths.
+
+#define WIDTH_CASE(type, width)                                                \
+  case width: {                                                                \
+    static const Capability caps[] = {Capability::type##width};                \
+    ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));                \
+    capabilities.push_back(ref);                                               \
+  } break
+
+  if (auto intType = dyn_cast<IntegerType>()) {
+    switch (bitwidth) {
+    case 32:
+    case 1:
+      break;
+      WIDTH_CASE(Int, 8);
+      WIDTH_CASE(Int, 16);
+      WIDTH_CASE(Int, 64);
+    default:
+      llvm_unreachable("invalid bitwidth to getCapabilities");
+    }
+  } else {
+    assert(isa<FloatType>());
+    switch (bitwidth) {
+    case 32:
+      break;
+      WIDTH_CASE(Float, 16);
+      WIDTH_CASE(Float, 64);
+    default:
+      llvm_unreachable("invalid bitwidth to getCapabilities");
+    }
+  }
+
+#undef WIDTH_CASE
+}
+
+//===----------------------------------------------------------------------===//
+// SPIRVType
+//===----------------------------------------------------------------------===//
+
+bool SPIRVType::classof(Type type) {
+  // Builtin integer/float types and vectors are accepted directly; everything
+  // else must fall inside the SPIR-V dialect's reserved type-kind range.
+  if (type.isa<ScalarType>() || type.isa<VectorType>())
+    return true;
+  auto kind = type.getKind();
+  return kind >= Type::FIRST_SPIRV_TYPE && kind <= TypeKind::LAST_SPIRV_TYPE;
+}
+
+/// Forwards the extension query to the concrete SPIR-V type class; reaching
+/// the end of the chain is a programming error.
+void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                              Optional<StorageClass> storage) {
+  if (auto scalar = dyn_cast<ScalarType>()) {
+    scalar.getExtensions(extensions, storage);
+    return;
+  }
+  if (auto composite = dyn_cast<CompositeType>()) {
+    composite.getExtensions(extensions, storage);
+    return;
+  }
+  if (auto pointer = dyn_cast<PointerType>()) {
+    pointer.getExtensions(extensions, storage);
+    return;
+  }
+  if (auto image = dyn_cast<ImageType>()) {
+    image.getExtensions(extensions, storage);
+    return;
+  }
+  llvm_unreachable("invalid SPIR-V Type to getExtensions");
+}
+
+/// Forwards the capability query to the concrete SPIR-V type class; reaching
+/// the end of the chain is a programming error.
+void SPIRVType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  if (auto scalar = dyn_cast<ScalarType>()) {
+    scalar.getCapabilities(capabilities, storage);
+    return;
+  }
+  if (auto composite = dyn_cast<CompositeType>()) {
+    composite.getCapabilities(capabilities, storage);
+    return;
+  }
+  if (auto pointer = dyn_cast<PointerType>()) {
+    pointer.getCapabilities(capabilities, storage);
+    return;
+  }
+  if (auto image = dyn_cast<ImageType>()) {
+    image.getCapabilities(capabilities, storage);
+    return;
+  }
+  llvm_unreachable("invalid SPIR-V Type to getCapabilities");
+}
+
 //===----------------------------------------------------------------------===//
 // StructType
 //===----------------------------------------------------------------------===//
@@ -540,18 +814,18 @@ unsigned StructType::getNumElements() const {
 }
 
 Type StructType::getElementType(unsigned index) const {
-  assert(
-      getNumElements() > index &&
-      "element index is more than number of members of the SPIR-V StructType");
+  assert(getNumElements() > index && "member index out of range");
   return getImpl()->memberTypes[index];
 }
 
+/// Returns a range over all member types, in member-index order, backed
+/// directly by the storage's member type array.
+StructType::ElementTypeRange StructType::getElementTypes() const {
+  return ElementTypeRange(getImpl()->memberTypes, getNumElements());
+}
+
 bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
 
 uint64_t StructType::getOffset(unsigned index) const {
-  assert(
-      getNumElements() > index &&
-      "element index is more than number of members of the SPIR-V StructType");
+  assert(getNumElements() > index && "member index out of range");
   return getImpl()->layoutInfo[index];
 }
 
@@ -579,3 +853,16 @@ void StructType::getMemberDecorations(
     }
   }
 }
+
+/// Collects the extension requirements of every member, in member order,
+/// querying each with the struct's own `storage` class.
+void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                               Optional<StorageClass> storage) {
+  for (Type memberType : getElementTypes()) {
+    auto spirvMember = memberType.cast<SPIRVType>();
+    spirvMember.getExtensions(extensions, storage);
+  }
+}
+
+/// Collects the capability requirements of every member, in member order,
+/// querying each with the struct's own `storage` class.
+void StructType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    Optional<StorageClass> storage) {
+  for (Type memberType : getElementTypes()) {
+    auto spirvMember = memberType.cast<SPIRVType>();
+    spirvMember.getCapabilities(capabilities, storage);
+  }
+}

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6647431b70fc..fff15c185749 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -33,6 +33,71 @@ class UpdateVCEPass final
 };
 } // namespace
 
+/// Shared implementation of the extension/capability requirement checks.
+///
+/// `candidates` follows the ((A OR B) AND (C OR D)) convention. For each inner
+/// list, the first member contained in `allowed` is inserted into `deduced`.
+/// At the first inner list with no allowed member, emits an error attached to
+/// `op`, then returns failure. The error names the requirement `kind`
+/// ("extension" or "capability") and lists that inner list's members, each
+/// spelled via `stringify`.
+template <typename EnumT, typename AllowedSetT, typename StringifyFn>
+static LogicalResult
+checkAndUpdateRequirements(Operation *op, StringRef kind,
+                           const AllowedSetT &allowed,
+                           const SmallVectorImpl<ArrayRef<EnumT>> &candidates,
+                           llvm::SetVector<EnumT> &deduced,
+                           StringifyFn stringify) {
+  for (const auto &ors : candidates) {
+    auto chosen = llvm::find_if(
+        ors, [&](EnumT value) { return allowed.count(value); });
+    if (chosen != ors.end()) {
+      deduced.insert(*chosen);
+      continue;
+    }
+
+    SmallVector<StringRef, 4> names;
+    for (EnumT value : ors)
+      names.push_back(stringify(value));
+
+    return op->emitError("'")
+           << op->getName() << "' requires at least one " << kind << " in ["
+           << llvm::join(names, ", ")
+           << "] but none allowed in target environment";
+  }
+  return success();
+}
+
+/// Checks that `candidates` extension requirements can be satisfied with the
+/// given `allowedExtensions`, inserting the chosen extensions into
+/// `deducedExtensions`. Emits errors attached to the given `op` on failure.
+///
+/// `candidates` is a vector of vector for extension requirements following
+/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
+/// convention.
+static LogicalResult checkAndUpdateExtensionRequirements(
+    Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+    llvm::SetVector<spirv::Extension> &deducedExtensions) {
+  return checkAndUpdateRequirements(
+      op, "extension", allowedExtensions, candidates, deducedExtensions,
+      [](spirv::Extension ext) { return spirv::stringifyExtension(ext); });
+}
+
+/// Checks that `candidates` capability requirements can be satisfied with the
+/// given `allowedCapabilities`, inserting the chosen capabilities into
+/// `deducedCapabilities`. Emits errors attached to the given `op` on failure.
+///
+/// `candidates` is a vector of vector for capability requirements following
+/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
+/// convention.
+static LogicalResult checkAndUpdateCapabilityRequirements(
+    Operation *op,
+    const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+    llvm::SetVector<spirv::Capability> &deducedCapabilities) {
+  return checkAndUpdateRequirements(
+      op, "capability", allowedCapabilities, candidates, deducedCapabilities,
+      [](spirv::Capability cap) { return spirv::stringifyCapability(cap); });
+}
+
 void UpdateVCEPass::runOnOperation() {
   spirv::ModuleOp module = getOperation();
 
@@ -70,6 +135,7 @@ void UpdateVCEPass::runOnOperation() {
   // Walk each SPIR-V op to deduce the minimal version/extension/capability
   // requirements.
   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
+    // Op min version requirements
     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
       deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
       if (deducedVersion > allowedVersion) {
@@ -80,62 +146,44 @@ void UpdateVCEPass::runOnOperation() {
       }
     }
 
-    // Deduce this op's extension requirement. For each op, the query interfacce
-    // returns a vector of vector for its extension requirements following
-    // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
-    // convention. Ops not implementing QueryExtensionInterface do not require
-    // extensions to be available.
-    if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
-      for (const auto &ors : extensions.getExtensions()) {
-        bool satisfied = false; // True when at least one extension can be used
-        for (spirv::Extension ext : ors) {
-          if (allowedExtensions.count(ext)) {
-            deducedExtensions.insert(ext);
-            satisfied = true;
-            break;
-          }
-        }
-
-        if (!satisfied) {
-          SmallVector<StringRef, 4> extStrings;
-          for (spirv::Extension ext : ors)
-            extStrings.push_back(spirv::stringifyExtension(ext));
-
-          return op->emitError("'")
-                 << op->getName() << "' requires at least one extension in ["
-                 << llvm::join(extStrings, ", ")
-                 << "] but none allowed in target environment";
-        }
-      }
-    }
-
-    // Deduce this op's capability requirement. For each op, the queryinterface
-    // returns a vector of vector for its capability requirements following
-    // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
-    // convention. Ops not implementing QueryExtensionInterface do not require
-    // extensions to be available.
-    if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
-      for (const auto &ors : capabilities.getCapabilities()) {
-        bool satisfied = false; // True when at least one capability can be used
-        for (spirv::Capability cap : ors) {
-          if (allowedCapabilities.count(cap)) {
-            deducedCapabilities.insert(cap);
-            satisfied = true;
-            break;
-          }
-        }
-
-        if (!satisfied) {
-          SmallVector<StringRef, 4> capStrings;
-          for (spirv::Capability cap : ors)
-            capStrings.push_back(spirv::stringifyCapability(cap));
-
-          return op->emitError("'")
-                 << op->getName() << "' requires at least one capability in ["
-                 << llvm::join(capStrings, ", ")
-                 << "] but none allowed in target environment";
-        }
-      }
+    // Op extension requirements
+    if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
+      if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions,
+                                                     extensions.getExtensions(),
+                                                     deducedExtensions)))
+        return WalkResult::interrupt();
+
+    // Op capability requirements
+    if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
+      if (failed(checkAndUpdateCapabilityRequirements(
+              op, allowedCapabilities, capabilities.getCapabilities(),
+              deducedCapabilities)))
+        return WalkResult::interrupt();
+
+    SmallVector<Type, 4> valueTypes;
+    valueTypes.append(op->operand_type_begin(), op->operand_type_end());
+    valueTypes.append(op->result_type_begin(), op->result_type_end());
+
+    // Special treatment for global variables, whose type requirements are
+    // conveyed by type attributes.
+    if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
+      valueTypes.push_back(globalVar.type());
+
+    // Requirements from values' types
+    SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
+    SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
+    for (Type valueType : valueTypes) {
+      typeExtensions.clear();
+      valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
+      if (failed(checkAndUpdateExtensionRequirements(
+              op, allowedExtensions, typeExtensions, deducedExtensions)))
+        return WalkResult::interrupt();
+
+      typeCapabilities.clear();
+      valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
+      if (failed(checkAndUpdateCapabilityRequirements(
+              op, allowedCapabilities, typeCapabilities, deducedCapabilities)))
+        return WalkResult::interrupt();
     }
 
     return WalkResult::advance();

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 60bf13e2571e..572db88e5f9e 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -107,6 +107,36 @@ spv.module Logical GLSL450 attributes {
   }
 }
 
+// Test capabilities required by the types of op operands and results
+
+// Using 8-bit integers outside of interface storage classes requires Int8.
+// CHECK: requires #spv.vce<v1.0, [Int8, Shader], []>
+spv.module Logical GLSL450 attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader, Int8], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  spv.func @iadd_function(%val : i8) -> i8 "None" {
+    %0 = spv.IAdd %val, %val : i8
+    spv.ReturnValue %0: i8
+  }
+}
+
+// Using 16-bit floats outside of interface storage classes requires Float16.
+// CHECK: requires #spv.vce<v1.0, [Float16, Shader], []>
+spv.module Logical GLSL450 attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader, Float16], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  spv.func @fadd_function(%val : f16) -> f16 "None" {
+    %0 = spv.FAdd %val, %val : f16
+    spv.ReturnValue %0: f16
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Extension
 //===----------------------------------------------------------------------===//
@@ -144,3 +174,35 @@ spv.module Logical Vulkan attributes {
     spv.ReturnValue %0: i32
   }
 }
+
+// Test extensions required by the types of op operands, results, and global
+// variables
+
+// Using 16-bit integers in an interface storage class requires additional
+// extensions and capabilities.
+// CHECK: requires #spv.vce<v1.0, [StorageBuffer16BitAccess, Shader, Int16], [SPV_KHR_16bit_storage, SPV_KHR_storage_buffer_storage_class]>
+spv.module Logical GLSL450 attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.3, [Shader, StorageBuffer16BitAccess, Int16], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  spv.func @iadd_storage_buffer(%ptr : !spv.ptr<i16, StorageBuffer>) -> i16 "None" {
+    %0 = spv.Load "StorageBuffer" %ptr : i16
+    %1 = spv.IAdd %0, %0 : i16
+    spv.ReturnValue %1: i16
+  }
+}
+
+// Complicated nested types
+// * i8 and f16 struct members in the Uniform storage class require the
+//   8-/16-bit uniform storage capabilities and extensions.
+// * Buffer requires ImageBuffer or SampledBuffer.
+// * Rg32f requires StorageImageExtendedFormats.
+// CHECK: requires #spv.vce<v1.0, [UniformAndStorageBuffer8BitAccess, StorageUniform16, Shader, ImageBuffer, StorageImageExtendedFormats], [SPV_KHR_8bit_storage, SPV_KHR_16bit_storage]>
+spv.module Logical GLSL450 attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.5, [Shader, UniformAndStorageBuffer8BitAccess, StorageBuffer16BitAccess, StorageUniform16, Int16, ImageBuffer, StorageImageExtendedFormats], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  spv.globalVariable @data : !spv.ptr<!spv.struct<i8 [0], f16 [2], i64 [4]>, Uniform>
+  spv.globalVariable @img  : !spv.ptr<!spv.image<f32, Buffer, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Rg32f>, UniformConstant>
+}


        


More information about the Mlir-commits mailing list