[Mlir-commits] [mlir] 58df5e6 - [mlir][spirv] Plumbing target environment into type converter

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:24 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:05-04:00
New Revision: 58df5e6d9ad93041d9c9b53a9e24d2be79796762

URL: https://github.com/llvm/llvm-project/commit/58df5e6d9ad93041d9c9b53a9e24d2be79796762
DIFF: https://github.com/llvm/llvm-project/commit/58df5e6d9ad93041d9c9b53a9e24d2be79796762.diff

LOG: [mlir][spirv] Plumbing target environment into type converter

This commit unifies target environment queries into a new wrapper
class spirv::TargetEnv and shares it across the various places that
need the functionality. We still create multiple instances of TargetEnv,
though, because the parent components (type converters, passes,
conversion targets) have different lifetimes.

In addition, LowerABIAttributesPass is updated to take the target
environment into consideration, which requires updating its tests to
provide one.

Differential Revision: https://reviews.llvm.org/D76242

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
    mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
    mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
    mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
    mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index b29c62e67116..a97c83b7a553 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
 #define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
 
+#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -27,7 +28,7 @@ namespace mlir {
 /// pointers to structs.
 class SPIRVTypeConverter : public TypeConverter {
 public:
-  SPIRVTypeConverter();
+  explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);
 
   /// Gets the SPIR-V correspondence for the standard index type.
   static Type getIndexType(MLIRContext *context);
@@ -40,6 +41,9 @@ class SPIRVTypeConverter : public TypeConverter {
   /// llvm::None if the memory space does not map to any SPIR-V storage class.
   static Optional<spirv::StorageClass>
   getStorageClassForMemorySpace(unsigned space);
+
+private:
+  spirv::TargetEnv targetEnv; ///< Target environment given at construction.
 };
 
 /// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
@@ -70,11 +74,10 @@ class FuncOp;
 class SPIRVConversionTarget : public ConversionTarget {
 public:
   /// Creates a SPIR-V conversion target for the given target environment.
-  static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetEnv,
-                                                    MLIRContext *context);
+  static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr);
 
 private:
-  SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context);
+  explicit SPIRVConversionTarget(TargetEnvAttr targetAttr);
 
   // Be explicit that instance of this class cannot be copied or moved: there
   // are lambdas capturing fields of the instance.
@@ -87,9 +90,7 @@ class SPIRVConversionTarget : public ConversionTarget {
   /// environment.
   bool isLegalOp(Operation *op);
 
-  Version givenVersion;                            /// SPIR-V version to target
-  llvm::SmallSet<Extension, 4> givenExtensions;    /// Allowed extensions
-  llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
+  TargetEnv targetEnv; ///< Environment ops are checked against.
 };
 
 /// Returns the value for the given `builtin` variable. This function gets or

diff  --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
index bf9c51e5b110..27278b6d3f23 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 class Operation;
@@ -22,6 +23,38 @@ class Operation;
 namespace spirv {
 enum class StorageClass : uint32_t;
 
+/// A wrapper around a spirv::TargetEnvAttr that answers queries about the
+/// allowed version, capabilities and extensions (including implied ones).
+class TargetEnv {
+public:
+  explicit TargetEnv(TargetEnvAttr targetAttr);
+
+  Version getVersion();
+
+  /// Returns true if the given capability is allowed.
+  bool allows(Capability) const;
+  /// Returns the first allowed one if any of the given capabilities is allowed.
+  /// Returns llvm::None otherwise.
+  Optional<Capability> allows(ArrayRef<Capability>) const;
+
+  /// Returns true if the given extension is allowed.
+  bool allows(Extension) const;
+  /// Returns the first allowed one if any of the given extensions is allowed.
+  /// Returns llvm::None otherwise.
+  Optional<Extension> allows(ArrayRef<Extension>) const;
+
+  /// Returns the MLIRContext.
+  MLIRContext *getContext();
+
+  /// Allows implicitly converting to the underlying spirv::TargetEnvAttr.
+  operator TargetEnvAttr() const { return targetAttr; }
+
+private:
+  TargetEnvAttr targetAttr;
+  llvm::SmallSet<Extension, 4> givenExtensions;    ///< Allowed extensions
+  llvm::SmallSet<Capability, 8> givenCapabilities; ///< Allowed capabilities
+};
+
 /// Returns the attribute name for specifying argument ABI information.
 StringRef getInterfaceVarABIAttrName();
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
index 4b84bc424fbd..272eb163ab69 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp
@@ -52,14 +52,15 @@ void GPUToSPIRVPass::runOnModule() {
     kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
-  SPIRVTypeConverter typeConverter;
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  std::unique_ptr<ConversionTarget> target =
+      spirv::SPIRVConversionTarget::get(targetAttr);
+  // Use the same target environment for legality checks and type conversion.
+  SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   populateGPUToSPIRVPatterns(context, typeConverter, patterns);
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
 
-  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
-      spirv::lookupTargetEnvOrDefault(module), context);
-
   if (failed(applyFullConversion(kernelModules, *target, patterns,
                                  &typeConverter))) {
     return signalPassFailure();

diff  --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index 68d31ca72479..4477c070796e 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -25,15 +25,15 @@ void LinalgToSPIRVPass::runOnModule() {
   MLIRContext *context = &getContext();
   ModuleOp module = getModule();
 
-  SPIRVTypeConverter typeConverter;
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  std::unique_ptr<ConversionTarget> target =
+      spirv::SPIRVConversionTarget::get(targetAttr);
+  // Use the same target environment for legality checks and type conversion.
+  SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
 
-  auto targetEnv = spirv::lookupTargetEnvOrDefault(module);
-  std::unique_ptr<ConversionTarget> target =
-      spirv::SPIRVConversionTarget::get(targetEnv, context);
-
   // Allow builtin ops.
   target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
   target->addDynamicallyLegalOp<FuncOp>(

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
index 7a3dae287d70..efccd168d6ea 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
@@ -31,14 +31,15 @@ void ConvertStandardToSPIRVPass::runOnModule() {
   MLIRContext *context = &getContext();
   ModuleOp module = getModule();
 
-  SPIRVTypeConverter typeConverter;
+  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
+  std::unique_ptr<ConversionTarget> target =
+      spirv::SPIRVConversionTarget::get(targetAttr);
+  // Use the same target environment for legality checks and type conversion.
+  SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
 
-  std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
-      spirv::lookupTargetEnvOrDefault(module), context);
-
   if (failed(applyPartialConversion(module, *target, patterns))) {
     return signalPassFailure();
   }

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 6d73432fead4..e5b630b82fb1 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -159,7 +159,8 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
   return llvm::None;
 }
 
-SPIRVTypeConverter::SPIRVTypeConverter() {
+SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
+    : targetEnv(targetAttr) {
   addConversion([](Type type) -> Optional<Type> {
     // If the type is already valid in SPIR-V, directly return.
     return spirv::SPIRVDialect::isValidType(type) ? type : Optional<Type>();
@@ -411,11 +412,10 @@ mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
 //===----------------------------------------------------------------------===//
 
 std::unique_ptr<spirv::SPIRVConversionTarget>
-spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
-                                  MLIRContext *context) {
+spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
   std::unique_ptr<SPIRVConversionTarget> target(
       // std::make_unique does not work here because the constructor is private.
-      new SPIRVConversionTarget(targetEnv, context));
+      new SPIRVConversionTarget(targetAttr));
   SPIRVConversionTarget *targetPtr = target.get();
   target->addDynamicallyLegalDialect<SPIRVDialect>(
       Optional<ConversionTarget::DynamicLegalityCallbackFn>(
@@ -426,80 +426,57 @@ spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
 }
 
 spirv::SPIRVConversionTarget::SPIRVConversionTarget(
-    spirv::TargetEnvAttr targetEnv, MLIRContext *context)
-    : ConversionTarget(*context), givenVersion(targetEnv.getVersion()) {
-  for (spirv::Extension ext : targetEnv.getExtensions())
-    givenExtensions.insert(ext);
-
-  // Add extensions implied by the current version.
-  for (spirv::Extension ext : spirv::getImpliedExtensions(givenVersion))
-    givenExtensions.insert(ext);
-
-  for (spirv::Capability cap : targetEnv.getCapabilities()) {
-    givenCapabilities.insert(cap);
-
-    // Add capabilities implied by the current capability.
-    for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
-      givenCapabilities.insert(c);
-  }
-}
+    spirv::TargetEnvAttr targetAttr)
+    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
 
 /// Checks that `candidates` extension requirements are possible to be satisfied
-/// with the given `allowedExtensions`.
+/// with the given `targetEnv`.
 ///
 ///  `candidates` is a vector of vector for extension requirements following
 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
 /// convention.
 static LogicalResult checkExtensionRequirements(
-    Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+    Operation *op, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
   for (const auto &ors : candidates) {
-    auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
-      return allowedExtensions.count(ext);
-    });
-
-    if (chosen == ors.end()) {
-      SmallVector<StringRef, 4> extStrings;
-      for (spirv::Extension ext : ors)
-        extStrings.push_back(spirv::stringifyExtension(ext));
-
-      LLVM_DEBUG(llvm::dbgs() << op->getName()
-                              << "illegal: requires at least one extension in ["
-                              << llvm::join(extStrings, ", ")
-                              << "] but none allowed in target environment\n");
-      return failure();
-    }
+    if (targetEnv.allows(ors))
+      continue;
+
+    SmallVector<StringRef, 4> extStrings;
+    for (spirv::Extension ext : ors)
+      extStrings.push_back(spirv::stringifyExtension(ext));
+
+    LLVM_DEBUG(llvm::dbgs() << op->getName()
+                            << " illegal: requires at least one extension in ["
+                            << llvm::join(extStrings, ", ")
+                            << "] but none allowed in target environment\n");
+    return failure();
   }
   return success();
 }
 
 /// Checks that `candidates`capability requirements are possible to be satisfied
-/// with the given `allowedCapabilities`.
+/// with the given `targetEnv`.
 ///
 ///  `candidates` is a vector of vector for capability requirements following
 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
 /// convention.
 static LogicalResult checkCapabilityRequirements(
-    Operation *op,
-    const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+    Operation *op, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
   for (const auto &ors : candidates) {
-    auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
-      return allowedCapabilities.count(cap);
-    });
+    if (targetEnv.allows(ors))
+      continue;
 
-    if (chosen == ors.end()) {
-      SmallVector<StringRef, 4> capStrings;
-      for (spirv::Capability cap : ors)
-        capStrings.push_back(spirv::stringifyCapability(cap));
+    SmallVector<StringRef, 4> capStrings;
+    for (spirv::Capability cap : ors)
+      capStrings.push_back(spirv::stringifyCapability(cap));
 
-      LLVM_DEBUG(llvm::dbgs()
-                 << op->getName()
-                 << "illegal: requires at least one capability in ["
-                 << llvm::join(capStrings, ", ")
-                 << "] but none allowed in target environment\n");
-      return failure();
-    }
+    LLVM_DEBUG(llvm::dbgs() << op->getName()
+                            << " illegal: requires at least one capability in ["
+                            << llvm::join(capStrings, ", ")
+                            << "] but none allowed in target environment\n");
+    return failure();
   }
   return success();
 }
@@ -509,7 +486,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
   // SPIR-V versions.
   if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
-    if (minVersion.getMinVersion() > givenVersion) {
+    if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
       LLVM_DEBUG(llvm::dbgs()
                  << op->getName() << " illegal: requiring min version "
                  << spirv::stringifyVersion(minVersion.getMinVersion())
@@ -517,7 +494,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
       return false;
     }
   if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
-    if (maxVersion.getMaxVersion() < givenVersion) {
+    if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
       LLVM_DEBUG(llvm::dbgs()
                  << op->getName() << " illegal: requiring max version "
                  << spirv::stringifyVersion(maxVersion.getMaxVersion())
@@ -529,7 +506,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // implementing QueryExtensionInterface do not require extensions to be
   // available.
   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
-    if (failed(checkExtensionRequirements(op, this->givenExtensions,
+    if (failed(checkExtensionRequirements(op, this->targetEnv,
                                           extensions.getExtensions())))
       return false;
 
@@ -537,7 +514,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   // implementing QueryCapabilityInterface do not require capabilities to be
   // available.
   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
-    if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
+    if (failed(checkCapabilityRequirements(op, this->targetEnv,
                                            capabilities.getCapabilities())))
       return false;
 
@@ -557,14 +534,13 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
   for (Type valueType : valueTypes) {
     typeExtensions.clear();
     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
-    if (failed(checkExtensionRequirements(op, this->givenExtensions,
-                                          typeExtensions)))
+    if (failed(checkExtensionRequirements(op, this->targetEnv, typeExtensions)))
       return false;
 
     typeCapabilities.clear();
     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
-    if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
-                                           typeCapabilities)))
+    if (failed(
+            checkCapabilityRequirements(op, this->targetEnv, typeCapabilities)))
       return false;
   }
 

diff  --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
index 16f79349ac1e..779a752e2e6c 100644
--- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
@@ -15,6 +15,67 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// TargetEnv
+//===----------------------------------------------------------------------===//
+
+spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
+    : targetAttr(targetAttr) {
+  for (spirv::Extension ext : targetAttr.getExtensions())
+    givenExtensions.insert(ext);
+
+  // Add extensions implied by the target version.
+  for (spirv::Extension ext :
+       spirv::getImpliedExtensions(targetAttr.getVersion()))
+    givenExtensions.insert(ext);
+
+  for (spirv::Capability cap : targetAttr.getCapabilities()) {
+    givenCapabilities.insert(cap);
+
+    // Add all capabilities recursively implied by this capability.
+    for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
+      givenCapabilities.insert(c);
+  }
+}
+
+spirv::Version spirv::TargetEnv::getVersion() {
+  return targetAttr.getVersion();
+}
+
+bool spirv::TargetEnv::allows(spirv::Capability capability) const {
+  return givenCapabilities.count(capability);
+}
+
+Optional<spirv::Capability>
+spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
+  auto chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
+    return givenCapabilities.count(cap);
+  });
+  if (chosen != caps.end())
+    return *chosen;
+  return llvm::None;
+}
+
+bool spirv::TargetEnv::allows(spirv::Extension extension) const {
+  return givenExtensions.count(extension);
+}
+
+Optional<spirv::Extension>
+spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
+  auto chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
+    return givenExtensions.count(ext);
+  });
+  if (chosen != exts.end())
+    return *chosen;
+  return llvm::None;
+}
+
+MLIRContext *spirv::TargetEnv::getContext() { return targetAttr.getContext(); }
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
 StringRef spirv::getInterfaceVarABIAttrName() {
   return "spv.interface_var_abi";
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index cb986fd8b282..516b9eca8544 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -224,7 +224,15 @@ void LowerABIAttributesPass::runOnOperation() {
   spirv::ModuleOp module = getOperation();
   MLIRContext *context = &getContext();
 
-  SPIRVTypeConverter typeConverter;
+  // SPIRVTypeConverter builds a spirv::TargetEnv from this attribute, which
+  // must therefore be non-null: report a missing one instead of crashing.
+  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
+  if (!targetAttr) {
+    module.emitError("missing 'spv.target_env' attribute");
+    return signalPassFailure();
+  }
+
+  SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
 

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index fff15c185749..201adbbd3837 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -34,22 +34,18 @@ class UpdateVCEPass final
 } // namespace
 
 /// Checks that `candidates` extension requirements are possible to be satisfied
-/// with the given `allowedExtensions` and updates `deducedExtensions` if so.
-/// Emits errors attaching to the given `op` on failures.
+/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
+/// errors attached to the given `op` on failure.
 ///
 ///  `candidates` is a vector of vector for extension requirements following
 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
 /// convention.
 static LogicalResult checkAndUpdateExtensionRequirements(
-    Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
+    Operation *op, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
     llvm::SetVector<spirv::Extension> &deducedExtensions) {
   for (const auto &ors : candidates) {
-    auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
-      return allowedExtensions.count(ext);
-    });
-
-    if (chosen != ors.end()) {
+    if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
       deducedExtensions.insert(*chosen);
     } else {
       SmallVector<StringRef, 4> extStrings;
@@ -66,23 +62,18 @@ static LogicalResult checkAndUpdateExtensionRequirements(
 }
 
 /// Checks that `candidates`capability requirements are possible to be satisfied
-/// with the given `allowedCapabilities` and updates `deducedCapabilities` if
-/// so. Emits errors attaching to the given `op` on failures.
+/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
+/// errors attached to the given `op` on failure.
 ///
 ///  `candidates` is a vector of vector for capability requirements following
 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
 /// convention.
 static LogicalResult checkAndUpdateCapabilityRequirements(
-    Operation *op,
-    const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
+    Operation *op, const spirv::TargetEnv &targetEnv,
     const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
     llvm::SetVector<spirv::Capability> &deducedCapabilities) {
   for (const auto &ors : candidates) {
-    auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
-      return allowedCapabilities.count(cap);
-    });
-
-    if (chosen != ors.end()) {
+    if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
       deducedCapabilities.insert(*chosen);
     } else {
       SmallVector<StringRef, 4> capStrings;
@@ -101,32 +92,14 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
 void UpdateVCEPass::runOnOperation() {
   spirv::ModuleOp module = getOperation();
 
-  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module);
-  if (!targetEnv) {
+  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
+  if (!targetAttr) {
     module.emitError("missing 'spv.target_env' attribute");
     return signalPassFailure();
   }
 
-  spirv::Version allowedVersion = targetEnv.getVersion();
-
-  // Build a set for available extensions in the target environment.
-  llvm::SmallSet<spirv::Extension, 4> allowedExtensions;
-  for (spirv::Extension ext : targetEnv.getExtensions())
-    allowedExtensions.insert(ext);
-
-  // Add extensions implied by the current version.
-  for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion))
-    allowedExtensions.insert(ext);
-
-  // Build a set for available capabilities in the target environment.
-  llvm::SmallSet<spirv::Capability, 8> allowedCapabilities;
-  for (spirv::Capability cap : targetEnv.getCapabilities()) {
-    allowedCapabilities.insert(cap);
-
-    // Add capabilities implied by the current capability.
-    for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
-      allowedCapabilities.insert(c);
-  }
+  spirv::TargetEnv targetEnv(targetAttr);
+  spirv::Version allowedVersion = targetAttr.getVersion();
 
   spirv::Version deducedVersion = spirv::Version::V_1_0;
   llvm::SetVector<spirv::Extension> deducedExtensions;
@@ -148,15 +121,14 @@ void UpdateVCEPass::runOnOperation() {
 
     // Op extension requirements
     if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
-      if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions,
-                                                     extensions.getExtensions(),
-                                                     deducedExtensions)))
+      if (failed(checkAndUpdateExtensionRequirements(
+              op, targetEnv, extensions.getExtensions(), deducedExtensions)))
         return WalkResult::interrupt();
 
     // Op capability requirements
     if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
       if (failed(checkAndUpdateCapabilityRequirements(
-              op, allowedCapabilities, capabilities.getCapabilities(),
+              op, targetEnv, capabilities.getCapabilities(),
               deducedCapabilities)))
         return WalkResult::interrupt();
 
@@ -176,13 +148,13 @@ void UpdateVCEPass::runOnOperation() {
       typeExtensions.clear();
       valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
       if (failed(checkAndUpdateExtensionRequirements(
-              op, allowedExtensions, typeExtensions, deducedExtensions)))
+              op, targetEnv, typeExtensions, deducedExtensions)))
         return WalkResult::interrupt();
 
       typeCapabilities.clear();
       valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
       if (failed(checkAndUpdateCapabilityRequirements(
-              op, allowedCapabilities, typeCapabilities, deducedCapabilities)))
+              op, targetEnv, typeCapabilities, deducedCapabilities)))
         return WalkResult::interrupt();
     }
 

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index a1f662300412..3972def985bb 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -1,5 +1,12 @@
 // RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
 
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+// -spirv-lower-abi-attrs requires the spv.target_env attribute above.
 // CHECK-LABEL: spv.module
 spv.module Logical GLSL450 {
   // CHECK-DAG:    spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
@@ -24,4 +31,6 @@ spv.module Logical GLSL450 {
   }
   // CHECK: spv.EntryPoint "GLCompute" [[FN]]
   // CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
-}
+} // end spv.module
+
+} // end module

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index f3158e310d79..42ff3f55e1ea 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -1,5 +1,12 @@
 // RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
 
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+// -spirv-lower-abi-attrs requires the spv.target_env attribute above.
 // CHECK-LABEL: spv.module
 spv.module Logical GLSL450 {
   // CHECK-DAG: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
@@ -119,4 +126,6 @@ spv.module Logical GLSL450 {
   }
   // CHECK: spv.EntryPoint "GLCompute" [[FN]], [[WORKGROUPID]], [[LOCALINVOCATIONID]], [[NUMWORKGROUPS]], [[WORKGROUPSIZE]]
   // CHECK-NEXT: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
-}
+} // end spv.module
+
+} // end module

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index a91800d68fc0..ad77e7d05f42 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -130,7 +130,13 @@ void ConvertToTargetEnv::runOnFunction() {
   auto targetEnv = fn.getOperation()
                        ->getAttr(spirv::getTargetEnvAttrName())
-                       .cast<spirv::TargetEnvAttr>();
-  auto target = spirv::SPIRVConversionTarget::get(targetEnv, context);
+                       .dyn_cast_or_null<spirv::TargetEnvAttr>();
+  // targetEnv is null if the attribute is absent or of another kind.
+  if (!targetEnv) {
+    fn.emitError("missing 'spv.target_env' attribute");
+    return signalPassFailure();
+  }
+
+  auto target = spirv::SPIRVConversionTarget::get(targetEnv);
 
   OwningRewritePatternList patterns;
   patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,


        


More information about the Mlir-commits mailing list