[clang] 73417c5 - [HLSL][clang][Driver] Support validator version command line option.

via cfe-commits cfe-commits at lists.llvm.org
Fri Apr 29 16:49:47 PDT 2022


Author: python3kgae
Date: 2022-04-29T16:48:08-07:00
New Revision: 73417c517644db5c419c85c0b3cb6750172fcab5

URL: https://github.com/llvm/llvm-project/commit/73417c517644db5c419c85c0b3cb6750172fcab5
DIFF: https://github.com/llvm/llvm-project/commit/73417c517644db5c419c85c0b3cb6750172fcab5.diff

LOG: [HLSL][clang][Driver] Support validator version command line option.

The DXIL validator version option (/validator-version) sets the validator version used when compiling HLSL.
The format is <major>.<minor>, e.g. 1.0.

Normally, the validator version would be obtained from the DXIL validator itself. Until the DXIL validator is available in llvm/main, this option is added so that the validator version can be set explicitly.

It affects code generation for DXIL, so the option is forwarded to -cc1 and consumed during code generation.

A new member std::string DxilValidatorVersion is added to clang::TargetOptions.

Then CGHLSLRuntime is added to clang::CodeGen::CodeGenModule.
It translates clang::TargetOptions::DxilValidatorVersion into a module flag with the key "dx.valver" at the end of clang code generation.

Reviewed By: beanz

Differential Revision: https://reviews.llvm.org/D123884

Added: 
    clang/lib/CodeGen/CGHLSLRuntime.cpp
    clang/lib/CodeGen/CGHLSLRuntime.h
    clang/test/CodeGenHLSL/validator_version.hlsl

Modified: 
    clang/include/clang/Basic/DiagnosticDriverKinds.td
    clang/include/clang/Basic/TargetOptions.h
    clang/include/clang/Driver/Options.td
    clang/lib/CodeGen/CMakeLists.txt
    clang/lib/CodeGen/CodeGenModule.cpp
    clang/lib/CodeGen/CodeGenModule.h
    clang/lib/Driver/ToolChains/Clang.cpp
    clang/lib/Driver/ToolChains/HLSL.cpp
    clang/lib/Driver/ToolChains/HLSL.h
    clang/unittests/Driver/ToolChainTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/DiagnosticDriverKinds.td b/clang/include/clang/Basic/DiagnosticDriverKinds.td
index 7ab7a8c0cd175..b35693462e33d 100644
--- a/clang/include/clang/Basic/DiagnosticDriverKinds.td
+++ b/clang/include/clang/Basic/DiagnosticDriverKinds.td
@@ -667,4 +667,15 @@ def err_drv_target_variant_invalid : Error<
 def err_drv_invalid_directx_shader_module : Error<
   "invalid profile : %0">;
 
+// Diagnostics for the DXC-mode -validator-version option; emitted by
+// isLegalValidatorVersion() in clang/lib/Driver/ToolChains/HLSL.cpp.
+def err_drv_invalid_range_dxil_validator_version : Error<
+  "invalid validator version : %0\n"
+  "Validator version must be less than or equal to current internal version.">;
+def err_drv_invalid_format_dxil_validator_version : Error<
+  "invalid validator version : %0\n"
+  "Format of validator version is \"<major>.<minor>\" (ex:\"1.4\").">;
+def err_drv_invalid_empty_dxil_validator_version : Error<
+  "invalid validator version : %0\n"
+  "If validator major version is 0, minor version must also be 0.">;
 }

diff  --git a/clang/include/clang/Basic/TargetOptions.h b/clang/include/clang/Basic/TargetOptions.h
index 009f25981ca93..611add6f92682 100644
--- a/clang/include/clang/Basic/TargetOptions.h
+++ b/clang/include/clang/Basic/TargetOptions.h
@@ -110,8 +110,11 @@ class TargetOptions {
   /// The version of the darwin target variant SDK which was used during the
   /// compilation.
   llvm::VersionTuple DarwinTargetVariantSDKVersion;
+
+  /// DXIL validator version as "<major>.<minor>" (from -validator-version).
+  std::string DxilValidatorVersion;
 };
 
-}  // end namespace clang
+} // end namespace clang
 
 #endif

diff  --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 15b94ee5425e9..ae95ab267c163 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -6736,20 +6736,21 @@ def _SLASH_ZW : CLJoined<"ZW">;
 
 def dxc_Group : OptionGroup<"<clang-dxc options>">, Flags<[DXCOption]>,
   HelpText<"dxc compatibility options">;
-
 class DXCJoinedOrSeparate<string name> : Option<["/", "-"], name,
   KIND_JOINED_OR_SEPARATE>, Group<dxc_Group>, Flags<[DXCOption, NoXarchOption]>;
 
 def dxc_help : Option<["/", "-", "--"], "help", KIND_JOINED>,
   Group<dxc_Group>, Flags<[DXCOption, NoXarchOption]>, Alias<help>,
   HelpText<"Display available options">;
-
-
 def Fo : DXCJoinedOrSeparate<"Fo">, Alias<o>,
-  HelpText<"Output object file.">;
-
+  HelpText<"Output object file">;
+def dxil_validator_version : Option<["/", "-"], "validator-version", KIND_SEPARATE>,
+  Group<dxc_Group>, Flags<[DXCOption, NoXarchOption, CC1Option, HelpHidden]>,
+  HelpText<"Override validator version for module. Format: <major.minor>; "
+           "Default: DXIL.dll version or current internal version">,
+  MarshallingInfoString<TargetOpts<"DxilValidatorVersion">>;
 def target_profile : DXCJoinedOrSeparate<"T">, MetaVarName<"<profile>">,
-  HelpText<"Set target profile.">,
+  HelpText<"Set target profile">,
   Values<"ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, ps_6_6, ps_6_7,"
          "vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, vs_6_6, vs_6_7,"
          "gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, gs_6_6, gs_6_7,"

diff  --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
new file mode 100644
index 0000000000000..f5392213f9fe2
--- /dev/null
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -0,0 +1,55 @@
+//===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides a class for HLSL code generation. It currently emits the
+// HLSL-specific module-level metadata (such as the DXIL validator version)
+// at the end of code generation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CGHLSLRuntime.h"
+#include "CodeGenModule.h"
+#include "clang/Basic/TargetOptions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+
+using namespace clang;
+using namespace CodeGen;
+using namespace llvm;
+
+/// Record the DXIL validator version as the "dx.valver" module flag, whose
+/// value is the metadata tuple !{i32 Major, i32 Minor}.
+///
+/// The driver validates the string in HLSLToolChain::TranslateArgs, but -cc1
+/// can be invoked directly, so anything that is not exactly "<major>.<minor>"
+/// (including the empty string when the option was not given) is ignored
+/// and no flag is emitted.
+static void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
+  VersionTuple Version;
+  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
+      Version.getSubminor() || !Version.getMinor())
+    return;
+
+  unsigned Major = Version.getMajor();
+  unsigned Minor = Version.getMinor().getValue();
+
+  LLVMContext &Ctx = M.getContext();
+  IRBuilder<> B(Ctx);
+  MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
+                                  ConstantAsMetadata::get(B.getInt32(Minor))});
+  StringRef DxilValKey = "dx.valver";
+  M.addModuleFlag(llvm::Module::ModFlagBehavior::AppendUnique, DxilValKey, Val);
+}
+
+void CGHLSLRuntime::finishCodeGen() {
+  auto &TargetOpts = CGM.getTarget().getTargetOpts();
+
+  llvm::Module &M = CGM.getModule();
+  addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
+}

diff  --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
new file mode 100644
index 0000000000000..268810f2ec9e6
--- /dev/null
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -0,0 +1,42 @@
+//===----- CGHLSLRuntime.h - Interface to HLSL Runtimes -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides a class for HLSL code generation. It currently emits the
+// HLSL-specific module-level metadata (such as the DXIL validator version)
+// at the end of code generation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
+#define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H
+
+namespace clang {
+
+namespace CodeGen {
+
+class CodeGenModule;
+
+/// HLSL-specific code generation support. Owned by CodeGenModule, which
+/// creates it only when compiling HLSL.
+class CGHLSLRuntime {
+protected:
+  CodeGenModule &CGM;
+
+public:
+  explicit CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {}
+  virtual ~CGHLSLRuntime() = default;
+
+  /// Emit HLSL module-level metadata, such as the "dx.valver" module flag.
+  /// Called from CodeGenModule::Release() at the end of code generation.
+  void finishCodeGen();
+};
+
+} // namespace CodeGen
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H

diff  --git a/clang/lib/CodeGen/CMakeLists.txt b/clang/lib/CodeGen/CMakeLists.txt
index c0b486daae3ab..0bb5abcf60455 100644
--- a/clang/lib/CodeGen/CMakeLists.txt
+++ b/clang/lib/CodeGen/CMakeLists.txt
@@ -51,6 +51,7 @@ add_clang_library(clangCodeGen
   CGExprConstant.cpp
   CGExprScalar.cpp
   CGGPUBuiltin.cpp
+  CGHLSLRuntime.cpp
   CGLoopInfo.cpp
   CGNonTrivialStruct.cpp
   CGObjC.cpp

diff  --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index 83f650cb8161d..f8bf210dc0e21 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -16,6 +16,7 @@
 #include "CGCXXABI.h"
 #include "CGCall.h"
 #include "CGDebugInfo.h"
+#include "CGHLSLRuntime.h"
 #include "CGObjCRuntime.h"
 #include "CGOpenCLRuntime.h"
 #include "CGOpenMPRuntime.h"
@@ -146,6 +147,8 @@ CodeGenModule::CodeGenModule(ASTContext &C, const HeaderSearchOptions &HSO,
     createOpenMPRuntime();
   if (LangOpts.CUDA)
     createCUDARuntime();
+  if (LangOpts.HLSL)
+    createHLSLRuntime();
 
   // Enable TBAA unless it's suppressed. ThreadSanitizer needs TBAA even at O0.
   if (LangOpts.Sanitize.has(SanitizerKind::Thread) ||
@@ -262,6 +265,10 @@ void CodeGenModule::createCUDARuntime() {
   CUDARuntime.reset(CreateNVCUDARuntime(*this));
 }
 
+void CodeGenModule::createHLSLRuntime() {
+  HLSLRuntime.reset(new CGHLSLRuntime(*this));
+}
+
 void CodeGenModule::addReplacement(StringRef Name, llvm::Constant *C) {
   Replacements[Name] = C;
 }
@@ -832,6 +839,10 @@ void CodeGenModule::Release() {
     }
   }
 
+  // Emit HLSL module-level metadata (e.g. the "dx.valver" module flag).
+  if (LangOpts.HLSL)
+    getHLSLRuntime().finishCodeGen();
+
   if (uint32_t PLevel = Context.getLangOpts().PICLevel) {
     assert(PLevel < 3 && "Invalid PIC Level");
     getModule().setPICLevel(static_cast<llvm::PICLevel::Level>(PLevel));

diff  --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index 1ba592bab6fc6..8393d43682ea5 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -85,6 +85,7 @@ class CGObjCRuntime;
 class CGOpenCLRuntime;
 class CGOpenMPRuntime;
 class CGCUDARuntime;
+class CGHLSLRuntime;
 class CoverageMappingModuleGen;
 class TargetCodeGenInfo;
 
@@ -319,6 +320,7 @@ class CodeGenModule : public CodeGenTypeCache {
   std::unique_ptr<CGOpenCLRuntime> OpenCLRuntime;
   std::unique_ptr<CGOpenMPRuntime> OpenMPRuntime;
   std::unique_ptr<CGCUDARuntime> CUDARuntime;
+  std::unique_ptr<CGHLSLRuntime> HLSLRuntime;
   std::unique_ptr<CGDebugInfo> DebugInfo;
   std::unique_ptr<ObjCEntrypoints> ObjCData;
   llvm::MDNode *NoObjCARCExceptionsMetadata = nullptr;
@@ -512,6 +514,7 @@ class CodeGenModule : public CodeGenTypeCache {
   void createOpenCLRuntime();
   void createOpenMPRuntime();
   void createCUDARuntime();
+  void createHLSLRuntime();
 
   bool isTriviallyRecursive(const FunctionDecl *F);
   bool shouldEmitFunction(GlobalDecl GD);
@@ -610,6 +613,12 @@ class CodeGenModule : public CodeGenTypeCache {
     return *CUDARuntime;
   }
 
+  /// Return a reference to the HLSL runtime, which exists only for HLSL.
+  CGHLSLRuntime &getHLSLRuntime() {
+    assert(HLSLRuntime != nullptr);
+    return *HLSLRuntime;
+  }
+
   ObjCEntrypoints &getObjCEntrypoints() const {
     assert(ObjCData != nullptr);
     return *ObjCData;

diff  --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index b5f9e0a13a6fe..c5f38762a5d61 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -3468,6 +3468,16 @@ static void RenderOpenCLOptions(const ArgList &Args, ArgStringList &CmdArgs,
   }
 }
 
+/// Forward the HLSL (DXC-mode) options that -cc1 understands, currently just
+/// -validator-version, to the -cc1 command line.
+static void RenderHLSLOptions(const ArgList &Args, ArgStringList &CmdArgs) {
+  const unsigned ForwardedArguments[] = {options::OPT_dxil_validator_version};
+
+  for (unsigned OptID : ForwardedArguments)
+    if (const auto *A = Args.getLastArg(OptID))
+      A->renderAsInput(Args, CmdArgs);
+}
+
 static void RenderARCMigrateToolOptions(const Driver &D, const ArgList &Args,
                                         ArgStringList &CmdArgs) {
   bool ARCMTEnabled = false;
@@ -6227,6 +6237,10 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA,
   // Forward -cl options to -cc1
   RenderOpenCLOptions(Args, CmdArgs, InputType);
 
+  // Forward hlsl options to -cc1
+  if (C.getDriver().IsDXCMode())
+    RenderHLSLOptions(Args, CmdArgs);
+
   if (IsHIP) {
     if (Args.hasFlag(options::OPT_fhip_new_launch_api,
                      options::OPT_fno_hip_new_launch_api, true))

diff  --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp
index a2a2a3a7bf552..2822e062fcd5c 100644
--- a/clang/lib/Driver/ToolChains/HLSL.cpp
+++ b/clang/lib/Driver/ToolChains/HLSL.cpp
@@ -108,6 +108,39 @@ std::string tryParseProfile(StringRef Profile) {
     return "";
 }
 
+/// Highest DXIL validator version supported by this compiler. It is also the
+/// default used when -validator-version is not given.
+/// TODO: read this from the DXIL validator once it is available.
+VersionTuple getMaxValidatorVersion() { return VersionTuple(1, 7); }
+
+/// Check that \p ValVersionStr is a legal DXIL validator version. It must be
+/// exactly "<major>.<minor>"; a major version of 0 is only legal as "0.0", and
+/// any other version must not exceed getMaxValidatorVersion(). On failure a
+/// driver diagnostic is emitted through \p D and false is returned.
+bool isLegalValidatorVersion(StringRef ValVersionStr, const Driver &D) {
+  VersionTuple Version;
+  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
+      Version.getSubminor() || !Version.getMinor()) {
+    D.Diag(diag::err_drv_invalid_format_dxil_validator_version)
+        << ValVersionStr;
+    return false;
+  }
+
+  uint64_t Major = Version.getMajor();
+  uint64_t Minor = Version.getMinor().getValue();
+  if (Major == 0 && Minor != 0) {
+    D.Diag(diag::err_drv_invalid_empty_dxil_validator_version) << ValVersionStr;
+    return false;
+  }
+  // "0.0" is legal. The range check is an upper bound, as the diagnostic says
+  // ("must be less than or equal to current internal version").
+  if (Version > getMaxValidatorVersion()) {
+    D.Diag(diag::err_drv_invalid_range_dxil_validator_version) << ValVersionStr;
+    return false;
+  }
+  return true;
+}
+
 } // namespace
 
 /// DirectX Toolchain
@@ -131,3 +164,30 @@ HLSLToolChain::ComputeEffectiveClangTriple(const ArgList &Args,
     return ToolChain::ComputeEffectiveClangTriple(Args, InputType);
   }
 }
+
+DerivedArgList *
+HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
+                             Action::OffloadKind DeviceOffloadKind) const {
+  DerivedArgList *DAL = new DerivedArgList(Args.getBaseArgs());
+
+  const OptTable &Opts = getDriver().getOpts();
+
+  for (Arg *A : Args) {
+    // Drop an invalid -validator-version (isLegalValidatorVersion has already
+    // diagnosed it) so that it is not forwarded to -cc1.
+    if (A->getOption().getID() == options::OPT_dxil_validator_version &&
+        !isLegalValidatorVersion(A->getValue(), getDriver()))
+      continue;
+    DAL->append(A);
+  }
+  // If no valid validator version was given, default to the highest one.
+  // TODO: remove this once the version can be read from the validator.
+  if (!DAL->hasArg(options::OPT_dxil_validator_version)) {
+    const std::string DefaultValidatorVer =
+        getMaxValidatorVersion().getAsString();
+    DAL->AddSeparateArg(nullptr,
+                        Opts.getOption(options::OPT_dxil_validator_version),
+                        DefaultValidatorVer);
+  }
+  return DAL;
+}

diff  --git a/clang/lib/Driver/ToolChains/HLSL.h b/clang/lib/Driver/ToolChains/HLSL.h
index 052003f53ae05..7774db3762dd3 100644
--- a/clang/lib/Driver/ToolChains/HLSL.h
+++ b/clang/lib/Driver/ToolChains/HLSL.h
@@ -26,6 +26,11 @@ class LLVM_LIBRARY_VISIBILITY HLSLToolChain : public ToolChain {
   }
   bool isPICDefaultForced() const override { return false; }
 
+  /// Validate -validator-version (diagnosing and dropping invalid values) and
+  /// add a default -validator-version when none is given.
+  llvm::opt::DerivedArgList *
+  TranslateArgs(const llvm::opt::DerivedArgList &Args, StringRef BoundArch,
+                Action::OffloadKind DeviceOffloadKind) const override;
   std::string ComputeEffectiveClangTriple(const llvm::opt::ArgList &Args,
                                           types::ID InputType) const override;
 };

diff  --git a/clang/test/CodeGenHLSL/validator_version.hlsl b/clang/test/CodeGenHLSL/validator_version.hlsl
new file mode 100644
index 0000000000000..eee83bd9677be
--- /dev/null
+++ b/clang/test/CodeGenHLSL/validator_version.hlsl
@@ -0,0 +1,14 @@
+// RUN: %clang -cc1 -S -triple dxil-pc-shadermodel6.3-library -emit-llvm -xhlsl -validator-version 1.1 -o - %s | FileCheck %s
+// RUN: %clang -cc1 -S -triple dxil-pc-shadermodel6.3-library -emit-llvm -xhlsl -o - %s | FileCheck %s --check-prefix=NOVER
+
+// CHECK:!"dx.valver", ![[valver:[0-9]+]]}
+// CHECK:![[valver]] = !{i32 1, i32 1}
+
+// Without -validator-version, -cc1 emits no "dx.valver" module flag.
+// NOVER-NOT: dx.valver
+
+float bar(float a, float b);
+
+float foo(float a, float b) {
+  return bar(a, b);
+}

diff  --git a/clang/unittests/Driver/ToolChainTest.cpp b/clang/unittests/Driver/ToolChainTest.cpp
index a0823f3ba123a..7abcb3ee0d975 100644
--- a/clang/unittests/Driver/ToolChainTest.cpp
+++ b/clang/unittests/Driver/ToolChainTest.cpp
@@ -367,29 +367,30 @@ TEST(GetDriverMode, PrefersLastDriverMode) {
   EXPECT_EQ(getDriverMode(Args[0], llvm::makeArrayRef(Args).slice(1)), "bar");
 }
 
+// Records the formatted text of every diagnostic, keeping errors separate
+// from all other levels, so tests can compare the exact messages.
+struct SimpleDiagnosticConsumer : public DiagnosticConsumer {
+  void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
+                        const Diagnostic &Info) override {
+    if (DiagLevel == DiagnosticsEngine::Level::Error) {
+      Errors.emplace_back();
+      Info.FormatDiagnostic(Errors.back());
+    } else {
+      Msgs.emplace_back();
+      Info.FormatDiagnostic(Msgs.back());
+    }
+  }
+  void clear() override {
+    Msgs.clear();
+    Errors.clear();
+    DiagnosticConsumer::clear();
+  }
+  std::vector<SmallString<32>> Msgs;
+  std::vector<SmallString<32>> Errors;
+};
+
 TEST(DxcModeTest, TargetProfileValidation) {
   IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs());
-  struct SimpleDiagnosticConsumer : public DiagnosticConsumer {
-    void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
-                          const Diagnostic &Info) override {
-      if (DiagLevel == DiagnosticsEngine::Level::Error) {
-        Errors.emplace_back();
-        Info.FormatDiagnostic(Errors.back());
-        Errors.back() += '\0';
-      } else {
-        Msgs.emplace_back();
-        Info.FormatDiagnostic(Msgs.back());
-        Msgs.back() += '\0';
-      }
-    }
-    void clear() override {
-      Msgs.clear();
-      Errors.clear();
-      DiagnosticConsumer::clear();
-    }
-    std::vector<SmallString<32>> Msgs;
-    std::vector<SmallString<32>> Errors;
-  };
 
   IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem(
       new llvm::vfs::InMemoryFileSystem);
@@ -474,7 +475,8 @@ TEST(DxcModeTest, TargetProfileValidation) {
   Triple = TC.ComputeEffectiveClangTriple(Args);
   EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
   EXPECT_EQ(Diags.getNumErrors(), 1u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : pss_6_1");
+  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
+               "invalid profile : pss_6_1");
   Diags.Clear();
   DiagConsumer->clear();
 
@@ -483,7 +485,7 @@ TEST(DxcModeTest, TargetProfileValidation) {
   Triple = TC.ComputeEffectiveClangTriple(Args);
   EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
   EXPECT_EQ(Diags.getNumErrors(), 2u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : ps_6_x");
+  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), "invalid profile : ps_6_x");
   Diags.Clear();
   DiagConsumer->clear();
 
@@ -492,7 +494,8 @@ TEST(DxcModeTest, TargetProfileValidation) {
   Triple = TC.ComputeEffectiveClangTriple(Args);
   EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
   EXPECT_EQ(Diags.getNumErrors(), 3u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : lib_6_1");
+  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
+               "invalid profile : lib_6_1");
   Diags.Clear();
   DiagConsumer->clear();
 
@@ -501,7 +504,115 @@ TEST(DxcModeTest, TargetProfileValidation) {
   Triple = TC.ComputeEffectiveClangTriple(Args);
   EXPECT_STREQ(Triple.c_str(), "unknown-unknown-shadermodel");
   EXPECT_EQ(Diags.getNumErrors(), 4u);
-  EXPECT_STREQ(DiagConsumer->Errors.back().data(), "invalid profile : foo");
+  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(), "invalid profile : foo");
+  Diags.Clear();
+  DiagConsumer->clear();
+}
+
+TEST(DxcModeTest, ValidatorVersionValidation) {
+  IntrusiveRefCntPtr<DiagnosticIDs> DiagID(new DiagnosticIDs());
+
+  IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> InMemoryFileSystem(
+      new llvm::vfs::InMemoryFileSystem);
+
+  InMemoryFileSystem->addFile("foo.hlsl", 0,
+                              llvm::MemoryBuffer::getMemBuffer("\n"));
+
+  auto *DiagConsumer = new SimpleDiagnosticConsumer;
+  IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
+  DiagnosticsEngine Diags(DiagID, &*DiagOpts, DiagConsumer);
+  Driver TheDriver("/bin/clang", "", Diags, "", InMemoryFileSystem);
+  std::unique_ptr<Compilation> C(
+      TheDriver.BuildCompilation({"clang", "--driver-mode=dxc", "foo.hlsl"}));
+  ASSERT_TRUE(C);
+  EXPECT_FALSE(C->containsError());
+
+  auto &TC = C->getDefaultToolChain();
+  bool ContainsError = false;
+  auto Args = TheDriver.ParseArgStrings({"-validator-version", "1.1"}, false,
+                                        ContainsError);
+  EXPECT_FALSE(ContainsError);
+  auto DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
+  for (auto *A : Args)
+    DAL->append(A);
+
+  // TranslateArgs returns a new list that the caller owns.
+  std::unique_ptr<llvm::opt::DerivedArgList> TranslatedArgs(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_NE(TranslatedArgs.get(), nullptr);
+  if (TranslatedArgs) {
+    auto *A = TranslatedArgs->getLastArg(
+        clang::driver::options::OPT_dxil_validator_version);
+    EXPECT_NE(A, nullptr);
+    if (A)
+      EXPECT_STREQ(A->getValue(), "1.1");
+  }
+  EXPECT_EQ(Diags.getNumErrors(), 0u);
+
+  // Invalid tests.
+  Args = TheDriver.ParseArgStrings({"-validator-version", "0.1"}, false,
+                                   ContainsError);
+  EXPECT_FALSE(ContainsError);
+  DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
+  for (auto *A : Args)
+    DAL->append(A);
+
+  TranslatedArgs.reset(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_EQ(Diags.getNumErrors(), 1u);
+  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
+               "invalid validator version : 0.1\nIf validator major version is "
+               "0, minor version must also be 0.");
+  Diags.Clear();
+  DiagConsumer->clear();
+
+  Args = TheDriver.ParseArgStrings({"-validator-version", "1"}, false,
+                                   ContainsError);
+  EXPECT_FALSE(ContainsError);
+  DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
+  for (auto *A : Args)
+    DAL->append(A);
+
+  TranslatedArgs.reset(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_EQ(Diags.getNumErrors(), 2u);
+  EXPECT_STREQ(DiagConsumer->Errors.back().c_str(),
+               "invalid validator version : 1\nFormat of validator version is "
+               "\"<major>.<minor>\" (ex:\"1.4\").");
+  Diags.Clear();
+  DiagConsumer->clear();
+
+  Args = TheDriver.ParseArgStrings({"-validator-version", "-Tlib_6_7"}, false,
+                                   ContainsError);
+  EXPECT_FALSE(ContainsError);
+  DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
+  for (auto *A : Args)
+    DAL->append(A);
+
+  TranslatedArgs.reset(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_EQ(Diags.getNumErrors(), 3u);
+  EXPECT_STREQ(
+      DiagConsumer->Errors.back().c_str(),
+      "invalid validator version : -Tlib_6_7\nFormat of validator version is "
+      "\"<major>.<minor>\" (ex:\"1.4\").");
+  Diags.Clear();
+  DiagConsumer->clear();
+
+  Args = TheDriver.ParseArgStrings({"-validator-version", "foo"}, false,
+                                   ContainsError);
+  EXPECT_FALSE(ContainsError);
+  DAL = std::make_unique<llvm::opt::DerivedArgList>(Args);
+  for (auto *A : Args)
+    DAL->append(A);
+
+  TranslatedArgs.reset(
+      TC.TranslateArgs(*DAL, "0", Action::OffloadKind::OFK_None));
+  EXPECT_EQ(Diags.getNumErrors(), 4u);
+  EXPECT_STREQ(
+      DiagConsumer->Errors.back().c_str(),
+      "invalid validator version : foo\nFormat of validator version is "
+      "\"<major>.<minor>\" (ex:\"1.4\").");
   Diags.Clear();
   DiagConsumer->clear();
 }


        


More information about the cfe-commits mailing list