[llvm] [SPIR-V] Expose an API call to initialize SPIRV target and translate input LLVM IR module to SPIR-V (PR #107216)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 03:22:51 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/107216

>From 66f792ddd769ee70e04e858a56a3adc6070acc57 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Sep 2024 03:41:54 -0700
Subject: [PATCH 1/5] the first draft of external SPIRV API

---
 llvm/lib/Target/SPIRV/CMakeLists.txt         |   1 +
 llvm/lib/Target/SPIRV/SPIRV.cpp              | 160 +++++++++++++++++++
 llvm/unittests/Target/SPIRV/CMakeLists.txt   |   2 +
 llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp |  98 ++++++++++++
 4 files changed, 261 insertions(+)
 create mode 100644 llvm/lib/Target/SPIRV/SPIRV.cpp
 create mode 100644 llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp

diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 5f8aea5fc8d84d..2c5c053fe72173 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -14,6 +14,7 @@ tablegen(LLVM SPIRVGenTables.inc -gen-searchable-tables)
 add_public_tablegen_target(SPIRVCommonTableGen)
 
 add_llvm_target(SPIRVCodeGen
+  SPIRV.cpp
   SPIRVAsmPrinter.cpp
   SPIRVBuiltins.cpp
   SPIRVCallLowering.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.cpp b/llvm/lib/Target/SPIRV/SPIRV.cpp
new file mode 100644
index 00000000000000..2636efb26128bf
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRV.cpp
@@ -0,0 +1,160 @@
+//===-- SPIRV.cpp - SPIR-V Backend API ------------------------*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/CodeGen/CommandFlags.h"
+// #include "llvm/CodeGen/LinkAllAsmWriterComponents.h"
+// #include "llvm/CodeGen/LinkAllCodegenComponents.h"
+// #include "llvm/CodeGen/MIRParser/MIRParser.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+// #include "llvm/IR/AutoUpgrade.h"
+#include "llvm/IR/DataLayout.h"
+// #include "llvm/IR/DiagnosticInfo.h"
+// #include "llvm/IR/DiagnosticPrinter.h"
+#include "llvm/IR/LLVMContext.h"
+// #include "llvm/IR/LLVMRemarkStreamer.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+// #include "llvm/IRReader/IRReader.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/MC/MCTargetOptionsCommandFlags.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Pass.h"
+// #include "llvm/Remarks/HotnessThresholdParser.h"
+#include "llvm/Support/CommandLine.h"
+// #include "llvm/Support/Debug.h"
+// #include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/InitLLVM.h"
+// #include "llvm/Support/PluginLoader.h"
+// #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+// #include "llvm/Support/TimeProfiler.h"
+// #include "llvm/Support/ToolOutputFile.h"
+// #include "llvm/Support/WithColor.h"
+#include "llvm/Target/TargetLoweringObjectFile.h"
+#include "llvm/Target/TargetMachine.h"
+// #include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/SubtargetFeature.h"
+#include "llvm/TargetParser/Triple.h"
+// #include "llvm/Transforms/Utils/Cloning.h"
+// #include <memory>
+#include <optional>
+// #include <ostream>
+#include <string>
+#include <utility>
+
+using namespace llvm;
+
+namespace {
+void parseSPIRVCommandLineOptions(const std::vector<std::string> &Options,
+                                  raw_ostream *Errs) {
+  static constexpr const char *Origin = "SPIRVTranslateModule";
+  if (!Options.empty()) {
+    std::vector<const char *> Argv(1, Origin);
+    for (const auto& Arg : Options)
+      Argv.push_back(Arg.c_str());
+    cl::ParseCommandLineOptions(Argv.size(), Argv.data(), Origin, Errs);
+  }
+}
+
+std::once_flag InitOnceFlag;
+void InitializeSPIRVTarget() {
+  std::call_once(InitOnceFlag, []() {
+    LLVMInitializeSPIRVTargetInfo();
+    LLVMInitializeSPIRVTarget();
+    LLVMInitializeSPIRVTargetMC();
+    LLVMInitializeSPIRVAsmPrinter();
+  });
+}
+} // namespace
+
+extern "C" LLVM_EXTERNAL_VISIBILITY bool
+SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
+                     const std::vector<std::string> &Opts) {
+  // Fallbacks for a Triple, MArch, Opt-level values.
+  static const std::string DefaultTriple = "spirv64-unknown-unknown";
+  static const std::string DefaultMArch = "";
+  static const llvm::CodeGenOptLevel OLevel = llvm::CodeGenOptLevel::None;
+
+  // Parse Opts as if it'd be command line argument.
+  std::string Errors;
+  raw_string_ostream ErrorStream(Errors);
+  parseSPIRVCommandLineOptions(Opts, &ErrorStream);
+  if (!Errors.empty()) {
+    ErrMsg = Errors;
+    return false;
+  }
+
+  // SPIR-V-specific target initialization.
+  InitializeSPIRVTarget();
+
+  Triple TargetTriple(M->getTargetTriple());
+  if (TargetTriple.getTriple().empty()) {
+    TargetTriple.setTriple(DefaultTriple);
+    M->setTargetTriple(DefaultTriple);
+  }
+  const Target *TheTarget =
+      TargetRegistry::lookupTarget(DefaultMArch, TargetTriple, ErrMsg);
+  if (!TheTarget)
+    return false;
+
+  // A call to codegen::InitTargetOptionsFromCodeGenFlags(TargetTriple)
+  // hits the following assertion: llvm/lib/CodeGen/CommandFlags.cpp:78:
+  // llvm::FPOpFusion::FPOpFusionMode llvm::codegen::getFuseFPOps(): Assertion
+  // `FuseFPOpsView && "RegisterCodeGenFlags not created."' failed.
+  TargetOptions Options;
+  std::optional<Reloc::Model> RM;
+  std::optional<CodeModel::Model> CM;
+  std::unique_ptr<TargetMachine> Target =
+      std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
+          TargetTriple.getTriple(), "", "", Options, RM, CM, OLevel));
+  if (!Target) {
+    ErrMsg = "Could not allocate target machine!";
+    return false;
+  }
+
+  if (M->getCodeModel())
+    Target->setCodeModel(*M->getCodeModel());
+
+  std::string DLStr = M->getDataLayoutStr();
+  Expected<DataLayout> MaybeDL = DataLayout::parse(
+      DLStr.empty() ? Target->createDataLayout().getStringRepresentation()
+                    : DLStr);
+  if (!MaybeDL) {
+    ErrMsg = toString(MaybeDL.takeError());
+    return false;
+  }
+  M->setDataLayout(MaybeDL.get());
+
+  TargetLibraryInfoImpl TLII(Triple(M->getTargetTriple()));
+  legacy::PassManager PM;
+  PM.add(new TargetLibraryInfoWrapperPass(TLII));
+  LLVMTargetMachine &LLVMTM = static_cast<LLVMTargetMachine &>(*Target);
+  // Own MMIWP locally; it is never added to PM, so nothing else would free it.
+  auto MMIWP = std::make_unique<MachineModuleInfoWrapperPass>(&LLVMTM);
+  const_cast<TargetLoweringObjectFile *>(LLVMTM.getObjFileLowering())
+      ->Initialize(MMIWP->getMMI().getContext(), *Target);
+
+  SmallString<4096> OutBuffer;
+  raw_svector_ostream OutStream(OutBuffer);
+  if (Target->addPassesToEmitFile(PM, OutStream, nullptr,
+                                  CodeGenFileType::ObjectFile)) {
+    ErrMsg = "Target machine cannot emit a file of this type";
+    return false;
+  }
+
+  PM.run(*M);
+  SpirvObj = OutBuffer.str();
+
+  return true;
+}
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index 83ae215c512ca2..e9fe4883e5b024 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -6,6 +6,7 @@ include_directories(
 set(LLVM_LINK_COMPONENTS
   Analysis
   AsmParser
+  BinaryFormat
   Core
   SPIRVCodeGen
   SPIRVAnalysis
@@ -14,5 +15,6 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_target_unittest(SPIRVTests
   SPIRVConvergenceRegionAnalysisTests.cpp
+  SPIRVAPITest.cpp
   )
 
diff --git a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
new file mode 100644
index 00000000000000..7bc7d3446762a6
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
@@ -0,0 +1,98 @@
+//===- llvm/unittest/CodeGen/SPIRVAPITest.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// Test that SPIR-V Backend provides an API call that translates LLVM IR Module
+/// into SPIR-V.
+//
+//===----------------------------------------------------------------------===//
+
+// #include "llvm/IR/LegacyPassManager.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/BinaryFormat/Magic.h"
+#include "llvm/IR/Module.h"
+// #include "llvm/MC/TargetRegistry.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/SourceMgr.h"
+// #include "llvm/Support/TargetSelect.h"
+// #include "llvm/Target/TargetMachine.h"
+#include "gtest/gtest.h"
+#include <string>
+#include <utility>
+
+namespace llvm {
+
+extern "C" bool SPIRVTranslateModule(Module *M, std::string &Buffer,
+                                     std::string &ErrMsg,
+                                     const std::vector<std::string> &Opts);
+
+class SPIRVAPITest : public testing::Test {
+protected:
+  /*
+    void SetUp() override {
+      EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+    }
+  */
+
+  bool toSpirv(StringRef Assembly, std::string &Result, std::string &ErrMsg,
+               const std::vector<std::string> &Opts) {
+    SMDiagnostic ParseError;
+    M = parseAssemblyString(Assembly, ParseError, Context);
+    if (!M) {
+      ParseError.print("IR parsing failed: ", errs());
+      report_fatal_error("Can't parse input assembly.");
+    }
+    return SPIRVTranslateModule(M.get(), Result, ErrMsg, Opts);
+  }
+
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+};
+
+TEST_F(SPIRVAPITest, checkTranslateExtError) {
+  StringRef Assembly = R"(
+    define dso_local spir_func void @test1() {
+    entry:
+      %res1 = tail call spir_func i32 @_Z26__spirv_GroupBitwiseAndKHR(i32 2, i32 0, i32 0)
+      ret void
+    }
+
+    declare dso_local spir_func i32  @_Z26__spirv_GroupBitwiseAndKHR(i32, i32, i32)
+  )";
+  std::string Result, Error;
+  std::vector<std::string> Opts;
+  bool Status = toSpirv(Assembly, Result, Error, Opts);
+  EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+}
+
+TEST_F(SPIRVAPITest, checkTranslateOk) {
+  StringRef Assemblies[] = {"", R"(
+    %struct = type { [1 x i64] }
+
+    define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
+    entry:
+      call spir_func void @bar(<2 x i32> noundef <i32 0, i32 1>)
+      ret void
+    }
+
+    define spir_func void @bar(<2 x i32> noundef) {
+    entry:
+      ret void
+    }
+  )"};
+  for (StringRef &Assembly : Assemblies) {
+    std::string Result, Error;
+    std::vector<std::string> Opts;
+    bool Status = toSpirv(Assembly, Result, Error, Opts);
+    EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+    EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+  }
+}
+
+} // end namespace llvm

>From 2806e8b8074567ddf9ca9c53e82c80ab7a161197 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Sep 2024 05:22:17 -0700
Subject: [PATCH 2/5] add tests and command line args support

---
 llvm/lib/Target/SPIRV/SPIRV.cpp              | 33 ++++------
 llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp | 67 ++++++++++++++------
 2 files changed, 58 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRV.cpp b/llvm/lib/Target/SPIRV/SPIRV.cpp
index 2636efb26128bf..805887c56429d8 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRV.cpp
@@ -8,48 +8,28 @@
 
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/CodeGen/CommandFlags.h"
-// #include "llvm/CodeGen/LinkAllAsmWriterComponents.h"
-// #include "llvm/CodeGen/LinkAllCodegenComponents.h"
-// #include "llvm/CodeGen/MIRParser/MIRParser.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
-// #include "llvm/IR/AutoUpgrade.h"
 #include "llvm/IR/DataLayout.h"
-// #include "llvm/IR/DiagnosticInfo.h"
-// #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/LLVMContext.h"
-// #include "llvm/IR/LLVMRemarkStreamer.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
-// #include "llvm/IRReader/IRReader.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/MC/MCTargetOptionsCommandFlags.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Pass.h"
-// #include "llvm/Remarks/HotnessThresholdParser.h"
 #include "llvm/Support/CommandLine.h"
-// #include "llvm/Support/Debug.h"
-// #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/FormattedStream.h"
 #include "llvm/Support/InitLLVM.h"
-// #include "llvm/Support/PluginLoader.h"
-// #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/TargetSelect.h"
-// #include "llvm/Support/TimeProfiler.h"
-// #include "llvm/Support/ToolOutputFile.h"
-// #include "llvm/Support/WithColor.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include "llvm/Target/TargetMachine.h"
-// #include "llvm/TargetParser/Host.h"
 #include "llvm/TargetParser/SubtargetFeature.h"
 #include "llvm/TargetParser/Triple.h"
-// #include "llvm/Transforms/Utils/Cloning.h"
-// #include <memory>
 #include <optional>
-// #include <ostream>
 #include <string>
 #include <utility>
 
@@ -61,7 +41,7 @@ void parseSPIRVCommandLineOptions(const std::vector<std::string> &Options,
   static constexpr const char *Origin = "SPIRVTranslateModule";
   if (!Options.empty()) {
     std::vector<const char *> Argv(1, Origin);
-    for (const auto& Arg : Options)
+    for (const auto &Arg : Options)
       Argv.push_back(Arg.c_str());
     cl::ParseCommandLineOptions(Argv.size(), Argv.data(), Origin, Errs);
   }
@@ -78,6 +58,13 @@ void InitializeSPIRVTarget() {
 }
 } // namespace
 
+namespace llvm {
+
+// The goal of this function is to facilitate integration of the SPIR-V Backend
+// into tools and libraries by exposing an API call that translates an LLVM
+// module to SPIR-V and writes the result into a string as binary SPIR-V,
+// with diagnostics on failure and a way to configure translation in a style
+// of command line options.
 extern "C" LLVM_EXTERNAL_VISIBILITY bool
 SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
                      const std::vector<std::string> &Opts) {
@@ -86,7 +73,7 @@ SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
   static const std::string DefaultMArch = "";
   static const llvm::CodeGenOptLevel OLevel = llvm::CodeGenOptLevel::None;
 
-  // Parse Opts as if it'd be command line argument.
+  // Parse Opts as if it'd be command line arguments.
   std::string Errors;
   raw_string_ostream ErrorStream(Errors);
   parseSPIRVCommandLineOptions(Opts, &ErrorStream);
@@ -158,3 +145,5 @@ SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
 
   return true;
 }
+
+} // namespace llvm
diff --git a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
index 7bc7d3446762a6..8561959d890f8a 100644
--- a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
+++ b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
@@ -12,15 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-// #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/BinaryFormat/Magic.h"
 #include "llvm/IR/Module.h"
-// #include "llvm/MC/TargetRegistry.h"
-#include "llvm/Pass.h"
 #include "llvm/Support/SourceMgr.h"
-// #include "llvm/Support/TargetSelect.h"
-// #include "llvm/Target/TargetMachine.h"
 #include "gtest/gtest.h"
 #include <string>
 #include <utility>
@@ -33,12 +28,6 @@ extern "C" bool SPIRVTranslateModule(Module *M, std::string &Buffer,
 
 class SPIRVAPITest : public testing::Test {
 protected:
-  /*
-    void SetUp() override {
-      EXPECT_TRUE(Status && Error.empty() && !Result.empty());
-    }
-  */
-
   bool toSpirv(StringRef Assembly, std::string &Result, std::string &ErrMsg,
                const std::vector<std::string> &Opts) {
     SMDiagnostic ParseError;
@@ -52,10 +41,8 @@ class SPIRVAPITest : public testing::Test {
 
   LLVMContext Context;
   std::unique_ptr<Module> M;
-};
 
-TEST_F(SPIRVAPITest, checkTranslateExtError) {
-  StringRef Assembly = R"(
+  static constexpr StringRef ExtensionAssembly = R"(
     define dso_local spir_func void @test1() {
     entry:
       %res1 = tail call spir_func i32 @_Z26__spirv_GroupBitwiseAndKHR(i32 2, i32 0, i32 0)
@@ -64,12 +51,7 @@ TEST_F(SPIRVAPITest, checkTranslateExtError) {
 
     declare dso_local spir_func i32  @_Z26__spirv_GroupBitwiseAndKHR(i32, i32, i32)
   )";
-  std::string Result, Error;
-  std::vector<std::string> Opts;
-  bool Status = toSpirv(Assembly, Result, Error, Opts);
-  EXPECT_TRUE(Status && Error.empty() && !Result.empty());
-  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
-}
+};
 
 TEST_F(SPIRVAPITest, checkTranslateOk) {
   StringRef Assemblies[] = {"", R"(
@@ -95,4 +77,49 @@ TEST_F(SPIRVAPITest, checkTranslateOk) {
   }
 }
 
+TEST_F(SPIRVAPITest, checkTranslateSupportExtension) {
+  std::string Result, Error;
+  std::vector<std::string> Opts{
+      "--spirv-ext=+SPV_KHR_uniform_group_instructions"};
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, Opts);
+  EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+}
+
+TEST_F(SPIRVAPITest, checkTranslateAllExtensions) {
+  std::string Result, Error;
+  std::vector<std::string> Opts{"--spirv-ext=all"};
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, Opts);
+  EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+}
+
+#if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
+TEST_F(SPIRVAPITest, checkTranslateExtensionError) {
+  std::string Result, Error;
+  std::vector<std::string> Opts;
+  EXPECT_DEATH_IF_SUPPORTED(
+      { toSpirv(ExtensionAssembly, Result, Error, Opts); },
+      "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
+      "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
+}
+
+TEST_F(SPIRVAPITest, checkTranslateUnknownExtension) {
+  std::string Result, Error;
+  std::vector<std::string> Opts{"--spirv-ext=+SPV_XYZ_my_unknown_extension"};
+  EXPECT_DEATH_IF_SUPPORTED(
+      { toSpirv(ExtensionAssembly, Result, Error, Opts); },
+      "SPIRVTranslateModule: for the --spirv-ext option: Unknown SPIR-V");
+}
+
+TEST_F(SPIRVAPITest, checkTranslateWrongExtension) {
+  std::string Result, Error;
+  std::vector<std::string> Opts{"--spirv-ext=+SPV_KHR_subgroup_rotate"};
+  EXPECT_DEATH_IF_SUPPORTED(
+      { toSpirv(ExtensionAssembly, Result, Error, Opts); },
+      "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
+      "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
+}
+#endif
+
 } // end namespace llvm

>From be3019f1b1a05cd606d1bcdccaec1d8f9b9f5301 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Sep 2024 06:55:18 -0700
Subject: [PATCH 3/5] expose an API call to initialize SPIRV target and
 translate input LLVM IR module to SPIR-V

---
 llvm/lib/Target/SPIRV/CMakeLists.txt          |  2 +-
 llvm/lib/Target/SPIRV/SPIRV.h                 |  1 +
 .../Target/SPIRV/{SPIRV.cpp => SPIRVAPI.cpp}  | 26 ++++++++--
 llvm/lib/Target/SPIRV/SPIRVAPI.h              | 23 +++++++++
 llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp  | 47 ++++++++++++++-----
 5 files changed, 82 insertions(+), 17 deletions(-)
 rename llvm/lib/Target/SPIRV/{SPIRV.cpp => SPIRVAPI.cpp} (85%)
 create mode 100644 llvm/lib/Target/SPIRV/SPIRVAPI.h

diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 2c5c053fe72173..df7869b1552caa 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -14,7 +14,7 @@ tablegen(LLVM SPIRVGenTables.inc -gen-searchable-tables)
 add_public_tablegen_target(SPIRVCommonTableGen)
 
 add_llvm_target(SPIRVCodeGen
-  SPIRV.cpp
+  SPIRVAPI.cpp
   SPIRVAsmPrinter.cpp
   SPIRVBuiltins.cpp
   SPIRVCallLowering.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6c35a467f53bef..7a7b3827b5e4a1 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -12,6 +12,7 @@
 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/Target/TargetMachine.h"
+#include "SPIRVAPI.h"
 
 namespace llvm {
 class SPIRVTargetMachine;
diff --git a/llvm/lib/Target/SPIRV/SPIRV.cpp b/llvm/lib/Target/SPIRV/SPIRVAPI.cpp
similarity index 85%
rename from llvm/lib/Target/SPIRV/SPIRV.cpp
rename to llvm/lib/Target/SPIRV/SPIRVAPI.cpp
index 805887c56429d8..b4ada1947a4888 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAPI.cpp
@@ -1,4 +1,4 @@
-//===-- SPIRV.cpp - SPIR-V Backend API ------------------------*- C++ -*---===//
+//===-- SPIRVAPI.cpp - SPIR-V Backend API ---------------------*- C++ -*---===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -32,10 +32,19 @@
 #include <optional>
 #include <string>
 #include <utility>
+#include <vector>
 
 using namespace llvm;
 
 namespace {
+
+// Mimic a few llc flags for a familiar interface; their names start with ' '
+// (e.g. "- mtriple=...") to avoid clashing with llc/codegen's own options.
+static cl::opt<char> SpvOptLevel(" O", cl::Hidden, cl::Prefix, cl::init('0'));
+static cl::opt<std::string> SpvTargetTriple(" mtriple", cl::Hidden,
+                                            cl::init(""));
+
+// Utility to accept options in a command line style.
 void parseSPIRVCommandLineOptions(const std::vector<std::string> &Options,
                                   raw_ostream *Errs) {
   static constexpr const char *Origin = "SPIRVTranslateModule";
@@ -68,10 +77,9 @@ namespace llvm {
 extern "C" LLVM_EXTERNAL_VISIBILITY bool
 SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
                      const std::vector<std::string> &Opts) {
-  // Fallbacks for a Triple, MArch, Opt-level values.
+  // Fallbacks for option values.
   static const std::string DefaultTriple = "spirv64-unknown-unknown";
   static const std::string DefaultMArch = "";
-  static const llvm::CodeGenOptLevel OLevel = llvm::CodeGenOptLevel::None;
 
   // Parse Opts as if it'd be command line arguments.
   std::string Errors;
@@ -82,10 +90,20 @@ SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
     return false;
   }
 
+  llvm::CodeGenOptLevel OLevel;
+  if (auto Level = CodeGenOpt::parseLevel(SpvOptLevel)) {
+    OLevel = *Level;
+  } else {
+    ErrMsg = "Invalid optimization level!";
+    return false;
+  }
+
   // SPIR-V-specific target initialization.
   InitializeSPIRVTarget();
 
-  Triple TargetTriple(M->getTargetTriple());
+  Triple TargetTriple(SpvTargetTriple.empty()
+                          ? M->getTargetTriple()
+                          : Triple::normalize(SpvTargetTriple));
   if (TargetTriple.getTriple().empty()) {
     TargetTriple.setTriple(DefaultTriple);
     M->setTargetTriple(DefaultTriple);
diff --git a/llvm/lib/Target/SPIRV/SPIRVAPI.h b/llvm/lib/Target/SPIRV/SPIRVAPI.h
new file mode 100644
index 00000000000000..c3786c6975a890
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVAPI.h
@@ -0,0 +1,23 @@
+//===-- SPIRVAPI.h - SPIR-V Backend API interface ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVAPI_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVAPI_H
+
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Module;
+
+extern "C" bool SPIRVTranslateModule(Module *M, std::string &Buffer,
+                                     std::string &ErrMsg,
+                                     const std::vector<std::string> &Opts);
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVAPI_H
diff --git a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
index 8561959d890f8a..a9bf9fb43f8f4f 100644
--- a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
+++ b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
@@ -17,9 +17,12 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
+#include <gmock/gmock.h>
 #include <string>
 #include <utility>
 
+using ::testing::StartsWith;
+
 namespace llvm {
 
 extern "C" bool SPIRVTranslateModule(Module *M, std::string &Buffer,
@@ -36,7 +39,10 @@ class SPIRVAPITest : public testing::Test {
       ParseError.print("IR parsing failed: ", errs());
       report_fatal_error("Can't parse input assembly.");
     }
-    return SPIRVTranslateModule(M.get(), Result, ErrMsg, Opts);
+    bool Status = SPIRVTranslateModule(M.get(), Result, ErrMsg, Opts);
+    if (!Status)
+      errs() << ErrMsg;
+    return Status;
   }
 
   LLVMContext Context;
@@ -51,10 +57,7 @@ class SPIRVAPITest : public testing::Test {
 
     declare dso_local spir_func i32  @_Z26__spirv_GroupBitwiseAndKHR(i32, i32, i32)
   )";
-};
-
-TEST_F(SPIRVAPITest, checkTranslateOk) {
-  StringRef Assemblies[] = {"", R"(
+  static constexpr StringRef OkAssembly = R"(
     %struct = type { [1 x i64] }
 
     define spir_kernel void @foo(ptr noundef byval(%struct) %arg) {
@@ -67,16 +70,36 @@ TEST_F(SPIRVAPITest, checkTranslateOk) {
     entry:
       ret void
     }
-  )"};
-  for (StringRef &Assembly : Assemblies) {
-    std::string Result, Error;
-    std::vector<std::string> Opts;
-    bool Status = toSpirv(Assembly, Result, Error, Opts);
-    EXPECT_TRUE(Status && Error.empty() && !Result.empty());
-    EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+  )";
+};
+
+TEST_F(SPIRVAPITest, checkTranslateOk) {
+  StringRef Assemblies[] = {"", OkAssembly};
+  // Those command line arguments that overlap with registered by llc/codegen
+  // are to be started with the ' ' symbol.
+  std::vector<std::string> SetOfOpts[] = {
+      {}, {"- mtriple=spirv32-unknown-unknown"}};
+  for (const auto &Opts : SetOfOpts) {
+    for (StringRef &Assembly : Assemblies) {
+      std::string Result, Error;
+      bool Status = toSpirv(Assembly, Result, Error, Opts);
+      EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+      EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+    }
   }
 }
 
+TEST_F(SPIRVAPITest, checkTranslateError) {
+  std::string Result, Error;
+  bool Status =
+      toSpirv(OkAssembly, Result, Error, {"-mtriple=spirv32-unknown-unknown"});
+  EXPECT_FALSE(Status);
+  EXPECT_TRUE(Result.empty());
+  EXPECT_THAT(Error,
+              StartsWith("SPIRVTranslateModule: Unknown command line argument "
+                         "'-mtriple=spirv32-unknown-unknown'"));
+}
+
 TEST_F(SPIRVAPITest, checkTranslateSupportExtension) {
   std::string Result, Error;
   std::vector<std::string> Opts{

>From a11d3cea894a9be2bf7e3c357e9c70fc10743f7f Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 4 Sep 2024 07:01:47 -0700
Subject: [PATCH 4/5] harden the unit test

---
 llvm/lib/Target/SPIRV/SPIRV.h                | 1 -
 llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp | 4 ++++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 7a7b3827b5e4a1..6c35a467f53bef 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -12,7 +12,6 @@
 #include "MCTargetDesc/SPIRVMCTargetDesc.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/Target/TargetMachine.h"
-#include "SPIRVAPI.h"
 
 namespace llvm {
 class SPIRVTargetMachine;
diff --git a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
index a9bf9fb43f8f4f..d58c1f3fe9b460 100644
--- a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
+++ b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
@@ -98,6 +98,10 @@ TEST_F(SPIRVAPITest, checkTranslateError) {
   EXPECT_THAT(Error,
               StartsWith("SPIRVTranslateModule: Unknown command line argument "
                          "'-mtriple=spirv32-unknown-unknown'"));
+  Status = toSpirv(OkAssembly, Result, Error, {"- O 5"});
+  EXPECT_FALSE(Status);
+  EXPECT_TRUE(Result.empty());
+  EXPECT_EQ(Error, "Invalid optimization level!");
 }
 
 TEST_F(SPIRVAPITest, checkTranslateSupportExtension) {

>From b2023ca443d49debc55b2b049a70adf35971cf53 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 10 Sep 2024 03:22:35 -0700
Subject: [PATCH 5/5] provide a list of allowed extensions as the API call
 argument

---
 llvm/lib/Target/SPIRV/SPIRVAPI.cpp           | 14 ++++
 llvm/lib/Target/SPIRV/SPIRVAPI.h             |  7 +-
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp   | 12 ++++
 llvm/lib/Target/SPIRV/SPIRVCommandLine.h     | 10 +++
 llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp     |  8 +++
 llvm/lib/Target/SPIRV/SPIRVSubtarget.h       |  6 ++
 llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp | 71 +++++++++++++++-----
 7 files changed, 109 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVAPI.cpp b/llvm/lib/Target/SPIRV/SPIRVAPI.cpp
index b4ada1947a4888..a6720d63c63b88 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAPI.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAPI.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "SPIRVCommandLine.h"
+#include "SPIRVSubtarget.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
@@ -76,6 +78,7 @@ namespace llvm {
 // of command line options.
 extern "C" LLVM_EXTERNAL_VISIBILITY bool
 SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
+                     const std::vector<std::string> &AllowExtNames,
                      const std::vector<std::string> &Opts) {
   // Fallbacks for option values.
   static const std::string DefaultTriple = "spirv64-unknown-unknown";
@@ -98,6 +101,17 @@ SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
     return false;
   }
 
+  // Override/amend the `-spirv-ext` command line switch (if present) with the
+  // explicit list of allowed SPIR-V extensions.
+  std::set<SPIRV::Extension::Extension> AllowedExtIds;
+  StringRef UnknownExt =
+      SPIRVExtensionsParser::checkExtensions(AllowExtNames, AllowedExtIds);
+  if (!UnknownExt.empty()) {
+    ErrMsg = "Unknown SPIR-V extension: " + UnknownExt.str();
+    return false;
+  }
+  SPIRVSubtarget::addExtensionsToClOpt(AllowedExtIds);
+
   // SPIR-V-specific target initialization.
   InitializeSPIRVTarget();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVAPI.h b/llvm/lib/Target/SPIRV/SPIRVAPI.h
index c3786c6975a890..cd41b8e595a654 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAPI.h
+++ b/llvm/lib/Target/SPIRV/SPIRVAPI.h
@@ -15,9 +15,10 @@
 namespace llvm {
 class Module;
 
-extern "C" bool SPIRVTranslateModule(Module *M, std::string &Buffer,
-                                     std::string &ErrMsg,
-                                     const std::vector<std::string> &Opts);
+extern "C" bool
+SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
+                     const std::vector<std::string> &AllowExtNames,
+                     const std::vector<std::string> &Opts);
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVAPI_H
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 90a9ab1d33ced4..127585f85915fb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -113,3 +113,15 @@ bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
   Vals = std::move(EnabledExtensions);
   return false;
 }
+
+llvm::StringRef SPIRVExtensionsParser::checkExtensions(
+    const std::vector<std::string> &ExtNames,
+    std::set<SPIRV::Extension::Extension> &AllowedExtensions) {
+  for (const auto &Ext : ExtNames) {
+    auto It = SPIRVExtensionMap.find(Ext);
+    if (It == SPIRVExtensionMap.end())
+      return Ext;
+    AllowedExtensions.insert(It->second);
+  }
+  return StringRef();
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.h b/llvm/lib/Target/SPIRV/SPIRVCommandLine.h
index 741d829b2ab8f9..8df2968eb6fe12 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.h
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.h
@@ -17,8 +17,10 @@
 #include "MCTargetDesc/SPIRVBaseInfo.h"
 #include "llvm/Support/CommandLine.h"
 #include <set>
+#include <string>
 
 namespace llvm {
+class StringRef;
 
 /// Command line parser for toggling SPIR-V extensions.
 struct SPIRVExtensionsParser
@@ -32,6 +34,14 @@ struct SPIRVExtensionsParser
   /// \return Returns true on error.
   bool parse(cl::Option &O, StringRef ArgName, StringRef ArgValue,
              std::set<SPIRV::Extension::Extension> &Vals);
+
+  /// Validates and converts extension names into internal enum values.
+  ///
+  /// \return The first unknown SPIR-V extension name in \p ExtNames (a view
+  /// into that vector's storage), or an empty StringRef on success.
+  static llvm::StringRef
+  checkExtensions(const std::vector<std::string> &ExtNames,
+                  std::set<SPIRV::Extension::Extension> &AllowedExtensions);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 27472923ee08c8..883ad14a20cb8e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -38,6 +38,14 @@ static cl::opt<std::set<SPIRV::Extension::Extension>, false,
     Extensions("spirv-ext",
                cl::desc("Specify list of enabled SPIR-V extensions"));
 
+// Lets code outside of this translation unit add entries to the static
+// cl::opt<...> `Extensions` variable; existing entries are preserved.
+void SPIRVSubtarget::addExtensionsToClOpt(
+    const std::set<SPIRV::Extension::Extension> &AllowList) {
+  for (const auto &Ext : AllowList)
+    Extensions.insert(Ext);
+}
+
 // Compare version numbers, but allow 0 to mean unspecified.
 static bool isAtLeastVer(VersionTuple Target, VersionTuple VerToCompareTo) {
   return Target.empty() || Target >= VerToCompareTo;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h
index 82ec3cc95cdd3f..984ba953e874f5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.h
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.h
@@ -130,6 +130,12 @@ class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
   }
 
   static constexpr unsigned MaxLegalAddressSpace = 6;
+
+  // Adds known SPIR-V extensions to the global set of allowed extensions that
+  // SPIRVSubtarget.cpp owns as the `-spirv-ext`
+  // cl::opt<std::set<SPIRV::Extension::Extension>, ...> global variable.
+  static void
+  addExtensionsToClOpt(const std::set<SPIRV::Extension::Extension> &AllowList);
 };
 } // namespace llvm
 
diff --git a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
index d58c1f3fe9b460..27ea8b8cf06e8d 100644
--- a/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
+++ b/llvm/unittests/Target/SPIRV/SPIRVAPITest.cpp
@@ -25,13 +25,15 @@ using ::testing::StartsWith;
 
 namespace llvm {
 
-extern "C" bool SPIRVTranslateModule(Module *M, std::string &Buffer,
-                                     std::string &ErrMsg,
-                                     const std::vector<std::string> &Opts);
+extern "C" bool
+SPIRVTranslateModule(Module *M, std::string &SpirvObj, std::string &ErrMsg,
+                     const std::vector<std::string> &AllowExtNames,
+                     const std::vector<std::string> &Opts);
 
 class SPIRVAPITest : public testing::Test {
 protected:
   bool toSpirv(StringRef Assembly, std::string &Result, std::string &ErrMsg,
+               const std::vector<std::string> &AllowExtNames,
                const std::vector<std::string> &Opts) {
     SMDiagnostic ParseError;
     M = parseAssemblyString(Assembly, ParseError, Context);
@@ -39,7 +41,8 @@ class SPIRVAPITest : public testing::Test {
       ParseError.print("IR parsing failed: ", errs());
       report_fatal_error("Can't parse input assembly.");
     }
-    bool Status = SPIRVTranslateModule(M.get(), Result, ErrMsg, Opts);
+    bool Status =
+        SPIRVTranslateModule(M.get(), Result, ErrMsg, AllowExtNames, Opts);
     if (!Status)
       errs() << ErrMsg;
     return Status;
@@ -82,7 +85,7 @@ TEST_F(SPIRVAPITest, checkTranslateOk) {
   for (const auto &Opts : SetOfOpts) {
     for (StringRef &Assembly : Assemblies) {
       std::string Result, Error;
-      bool Status = toSpirv(Assembly, Result, Error, Opts);
+      bool Status = toSpirv(Assembly, Result, Error, {}, Opts);
       EXPECT_TRUE(Status && Error.empty() && !Result.empty());
       EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
     }
@@ -91,24 +94,42 @@ TEST_F(SPIRVAPITest, checkTranslateOk) {
 
 TEST_F(SPIRVAPITest, checkTranslateError) {
   std::string Result, Error;
-  bool Status =
-      toSpirv(OkAssembly, Result, Error, {"-mtriple=spirv32-unknown-unknown"});
+  bool Status = toSpirv(OkAssembly, Result, Error, {},
+                        {"-mtriple=spirv32-unknown-unknown"});
   EXPECT_FALSE(Status);
   EXPECT_TRUE(Result.empty());
   EXPECT_THAT(Error,
               StartsWith("SPIRVTranslateModule: Unknown command line argument "
                          "'-mtriple=spirv32-unknown-unknown'"));
-  Status = toSpirv(OkAssembly, Result, Error, {"- O 5"});
+  Status = toSpirv(OkAssembly, Result, Error, {}, {"- O 5"});
   EXPECT_FALSE(Status);
   EXPECT_TRUE(Result.empty());
   EXPECT_EQ(Error, "Invalid optimization level!");
 }
 
-TEST_F(SPIRVAPITest, checkTranslateSupportExtension) {
+TEST_F(SPIRVAPITest, checkTranslateSupportExtensionByOpts) {
   std::string Result, Error;
   std::vector<std::string> Opts{
       "--spirv-ext=+SPV_KHR_uniform_group_instructions"};
-  bool Status = toSpirv(ExtensionAssembly, Result, Error, Opts);
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, {}, Opts);
+  EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+}
+
+TEST_F(SPIRVAPITest, checkTranslateSupportExtensionByArg) {
+  std::string Result, Error;
+  std::vector<std::string> ExtNames{"SPV_KHR_uniform_group_instructions"};
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, ExtNames, {});
+  EXPECT_TRUE(Status && Error.empty() && !Result.empty());
+  EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
+}
+
+TEST_F(SPIRVAPITest, checkTranslateSupportExtensionByArgList) {
+  std::string Result, Error;
+  std::vector<std::string> ExtNames{"SPV_KHR_subgroup_rotate",
+                                    "SPV_KHR_uniform_group_instructions",
+                                    "SPV_KHR_subgroup_rotate"};
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, ExtNames, {});
   EXPECT_TRUE(Status && Error.empty() && !Result.empty());
   EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
 }
@@ -116,34 +137,52 @@ TEST_F(SPIRVAPITest, checkTranslateSupportExtension) {
 TEST_F(SPIRVAPITest, checkTranslateAllExtensions) {
   std::string Result, Error;
   std::vector<std::string> Opts{"--spirv-ext=all"};
-  bool Status = toSpirv(ExtensionAssembly, Result, Error, Opts);
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, {}, Opts);
   EXPECT_TRUE(Status && Error.empty() && !Result.empty());
   EXPECT_EQ(identify_magic(Result), file_magic::spirv_object);
 }
 
+TEST_F(SPIRVAPITest, checkTranslateUnknownExtensionByArg) {
+  std::string Result, Error;
+  std::vector<std::string> ExtNames{"SPV_XYZ_my_unknown_extension"};
+  bool Status = toSpirv(ExtensionAssembly, Result, Error, ExtNames, {});
+  EXPECT_FALSE(Status);
+  EXPECT_TRUE(Result.empty());
+  EXPECT_EQ(Error, "Unknown SPIR-V extension: SPV_XYZ_my_unknown_extension");
+}
+
 #if !defined(NDEBUG) && GTEST_HAS_DEATH_TEST
 TEST_F(SPIRVAPITest, checkTranslateExtensionError) {
   std::string Result, Error;
   std::vector<std::string> Opts;
   EXPECT_DEATH_IF_SUPPORTED(
-      { toSpirv(ExtensionAssembly, Result, Error, Opts); },
+      { toSpirv(ExtensionAssembly, Result, Error, {}, Opts); },
       "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
       "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
 }
 
-TEST_F(SPIRVAPITest, checkTranslateUnknownExtension) {
+TEST_F(SPIRVAPITest, checkTranslateUnknownExtensionByOpts) {
   std::string Result, Error;
   std::vector<std::string> Opts{"--spirv-ext=+SPV_XYZ_my_unknown_extension"};
   EXPECT_DEATH_IF_SUPPORTED(
-      { toSpirv(ExtensionAssembly, Result, Error, Opts); },
+      { toSpirv(ExtensionAssembly, Result, Error, {}, Opts); },
       "SPIRVTranslateModule: for the --spirv-ext option: Unknown SPIR-V");
 }
 
-TEST_F(SPIRVAPITest, checkTranslateWrongExtension) {
+TEST_F(SPIRVAPITest, checkTranslateWrongExtensionByOpts) {
   std::string Result, Error;
   std::vector<std::string> Opts{"--spirv-ext=+SPV_KHR_subgroup_rotate"};
   EXPECT_DEATH_IF_SUPPORTED(
-      { toSpirv(ExtensionAssembly, Result, Error, Opts); },
+      { toSpirv(ExtensionAssembly, Result, Error, {}, Opts); },
+      "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
+      "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
+}
+
+TEST_F(SPIRVAPITest, checkTranslateWrongExtensionByArg) {
+  std::string Result, Error;
+  std::vector<std::string> ExtNames{"SPV_KHR_subgroup_rotate"};
+  EXPECT_DEATH_IF_SUPPORTED(
+      { toSpirv(ExtensionAssembly, Result, Error, ExtNames, {}); },
       "LLVM ERROR: __spirv_GroupBitwiseAndKHR: the builtin requires the "
       "following SPIR-V extension: SPV_KHR_uniform_group_instructions");
 }



More information about the llvm-commits mailing list