[llvm] [NFC][SPIRV] Re-work extension parsing (PR #171826)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 11 05:12:29 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Alex Voicu (AlexVlx)

<details>
<summary>Changes</summary>

This changes the extension parsing mechanism underpinning `--spirv-ext` so that it is more explicit about what it does and no longer relies on sorting the tokens. Specifically, we partition the extension tokens into enabled ones (prefixed with `+`) and all others, and then handle each resulting range separately.

---
Full diff: https://github.com/llvm/llvm-project/pull/171826.diff


1 File Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+30-33) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 42edad255ce82..04c54f9b0e53d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -17,7 +17,9 @@
 #include "llvm/TargetParser/Triple.h"
 
 #include <functional>
+#include <iterator>
 #include <map>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -26,7 +28,7 @@
 
 using namespace llvm;
 
-static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
+static const std::map<StringRef, SPIRV::Extension::Extension>
     SPIRVExtensionMap = {
         {"SPV_EXT_shader_atomic_float_add",
          SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_add},
@@ -181,57 +183,52 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   std::set<SPIRV::Extension::Extension> &Vals) {
   SmallVector<StringRef, 10> Tokens;
   ArgValue.split(Tokens, ",", -1, false);
-  llvm::sort(Tokens, [](auto &&LHS, auto &&RHS) {
-    // We want to ensure that we handle "all" first, to ensure that any
-    // subsequent disablement actually behaves as expected i.e. given
-    // --spv-ext=all,-foo, we first enable all and then disable foo; this should
-    // be revisited and simplified.
-    if (LHS == "all")
-      return true;
-    if (RHS == "all")
-      return false;
-    return !(RHS < LHS);
-  });
 
   std::set<SPIRV::Extension::Extension> EnabledExtensions;
 
-  for (const auto &Token : Tokens) {
-    if (Token == "all") {
-      for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
-        EnabledExtensions.insert(ExtensionEnum);
+  auto M = partition(Tokens, [](auto &&T) { return T.starts_with('+'); });
+
+  if (std::any_of(M, Tokens.end(), [](auto &&T) { return T == "all"; }))
+    copy(make_second_range(SPIRVExtensionMap), std::inserter(Vals, Vals.end()));
+
+  for (auto &&Token : make_range(Tokens.begin(), M)) {
+    StringRef ExtensionName = Token.substr(1);
+    auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
 
+    if (NameValuePair == SPIRVExtensionMap.end())
+      return O.error("Unknown SPIR-V extension: " + Token.str());
+
+    EnabledExtensions.insert(NameValuePair->second);
+  }
+
+  for (auto &&Token : make_range(M, Tokens.end())) {
+    if (Token == "all")
       continue;
-    }
 
     if (Token.size() == 3 && Token.upper() == "KHR") {
       for (const auto &[ExtensionName, ExtensionEnum] : SPIRVExtensionMap)
         if (StringRef(ExtensionName).starts_with("SPV_KHR_"))
-          EnabledExtensions.insert(ExtensionEnum);
+          Vals.insert(ExtensionEnum);
       continue;
     }
 
     if (Token.empty() || (!Token.starts_with("+") && !Token.starts_with("-")))
-      return O.error("Invalid extension list format: " + Token.str());
+      return O.error("Invalid extension list format: " + Token);
 
-    StringRef ExtensionName = Token.substr(1);
-    auto NameValuePair = SPIRVExtensionMap.find(ExtensionName);
+    auto NameValuePair = SPIRVExtensionMap.find(Token.substr(1));
 
-    if (NameValuePair == SPIRVExtensionMap.end())
+    if (NameValuePair == SPIRVExtensionMap.cend())
       return O.error("Unknown SPIR-V extension: " + Token.str());
+    if (EnabledExtensions.count(NameValuePair->second))
+      return O.error(
+          "Extension cannot be allowed and disallowed at the same time: " +
+          NameValuePair->first);
 
-    if (Token.starts_with("+")) {
-      EnabledExtensions.insert(NameValuePair->second);
-    } else if (EnabledExtensions.count(NameValuePair->second)) {
-      if (llvm::is_contained(Tokens, "+" + ExtensionName.str()))
-        return O.error(
-            "Extension cannot be allowed and disallowed at the same time: " +
-            ExtensionName.str());
-
-      EnabledExtensions.erase(NameValuePair->second);
-    }
+    Vals.erase(NameValuePair->second);
   }
 
-  Vals = std::move(EnabledExtensions);
+  Vals.insert(EnabledExtensions.cbegin(), EnabledExtensions.cend());
+
   return false;
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/171826


More information about the llvm-commits mailing list