[llvm] [DirectX] remove string function attributes not allowed in DXIL (PR #90778)

Xiang Li via llvm-commits llvm-commits at lists.llvm.org
Tue May 7 07:29:56 PDT 2024


https://github.com/python3kgae updated https://github.com/llvm/llvm-project/pull/90778

>From af945856d226ac6cfcf695d3836c029db551134f Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Wed, 1 May 2024 16:58:44 -0400
Subject: [PATCH 1/6] [DirectX] remove module flags and function attributes
 not allowed in DXIL

Remove module flags other than "Dwarf Version" and "Debug Info Version".
Remove string function attributes other than "waveops-include-helper-lanes" and "fp32-denorm-mode".

Move DXILPrepareModulePass after DXILTranslateMetadataPass, since
DXILTranslateMetadataPass needs to read attributes like hlsl.numthreads,
which DXILPrepareModulePass now removes.

Fixes #90773
---
 llvm/lib/Target/DirectX/DXILPrepare.cpp       | 65 ++++++++++++++++++-
 .../Target/DirectX/DirectXTargetMachine.cpp   |  2 +-
 .../DirectX/Metadata/shaderModel-cs.ll        |  4 ++
 llvm/test/CodeGen/DirectX/dxil_ver.ll         |  5 ++
 llvm/test/tools/dxil-dis/attribute-filter.ll  |  4 +-
 5 files changed, 76 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 026911946b47f..e3a7ddee083b0 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -11,10 +11,13 @@
 /// Language (DXIL).
 //===----------------------------------------------------------------------===//
 
+#include "DXILResourceAnalysis.h"
+#include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "DirectXIRPasses/PointerTypeAnalysis.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/AttributeMask.h"
 #include "llvm/IR/IRBuilder.h"
@@ -80,6 +83,60 @@ constexpr bool isValidForDXIL(Attribute::AttrKind Attr) {
                       Attr);
 }
 
+static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
+                                   StringSet<> LiveKeys) {
+  for (auto &Attr : AS) {
+    if (!Attr.isStringAttribute())
+      continue;
+    StringRef Key = Attr.getKindAsString();
+    if (LiveKeys.contains(Key))
+      continue;
+    DeadAttrs.addAttribute(Key);
+  }
+}
+
+static void removeStringFunctionAttributes(Function &F) {
+  AttributeList Attrs = F.getAttributes();
+  StringSet<> LiveKeys = {"waveops-include-helper-lanes"
+                          "fp32-denorm-mode"};
+  // Collect DeadKeys in FnAttrs.
+  AttributeMask DeadAttrs;
+  collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys);
+  collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys);
+
+  F.removeFnAttrs(DeadAttrs);
+  F.removeRetAttrs(DeadAttrs);
+}
+
+static void cleanModuleFlags(Module &M) {
+  NamedMDNode *MDFlags = M.getModuleFlagsMetadata();
+  if (!MDFlags)
+    return;
+
+  StringSet<> LiveKeys = {"Dwarf Version", "Debug Info Version"};
+
+  SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries;
+  M.getModuleFlagsMetadata(FlagEntries);
+
+  bool HasDeadKey = false;
+  for (auto &Flag : FlagEntries) {
+    if (!LiveKeys.count(Flag.Key->getString())) {
+      HasDeadKey = true;
+      break;
+    }
+  }
+  if (!HasDeadKey)
+    return;
+
+  MDFlags->eraseFromParent();
+
+  for (auto &Flag : FlagEntries) {
+    if (!LiveKeys.count(Flag.Key->getString()))
+      continue;
+    M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);
+  }
+}
+
 class DXILPrepareModule : public ModulePass {
 
   static Value *maybeGenerateBitcast(IRBuilder<> &Builder,
@@ -113,6 +170,7 @@ class DXILPrepareModule : public ModulePass {
     for (auto &F : M.functions()) {
       F.removeFnAttrs(AttrMask);
       F.removeRetAttrs(AttrMask);
+      removeStringFunctionAttributes(F);
       for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx)
         F.removeParamAttrs(Idx, AttrMask);
 
@@ -168,11 +226,16 @@ class DXILPrepareModule : public ModulePass {
         }
       }
     }
+    // Remove flags not in llvm3.7.
+    cleanModuleFlags(M);
     return true;
   }
 
   DXILPrepareModule() : ModulePass(ID) {}
-
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+    AU.addPreserved<DXILResourceWrapper>();
+  }
   static char ID; // Pass identification.
 };
 char DXILPrepareModule::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index bebca0675522f..c853393e4282a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -79,8 +79,8 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
-    addPass(createDXILPrepareModulePass());
     addPass(createDXILTranslateMetadataPass());
+    addPass(createDXILPrepareModulePass());
   }
 };
 
diff --git a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll
index be4b46f22ef25..f617ad8d299ef 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll
@@ -1,4 +1,6 @@
 ; RUN: opt -S -dxil-metadata-emit %s | FileCheck %s
+; RUN: opt -S -dxil-prepare  %s | FileCheck %s  --check-prefix=REMOVE_EXTRA_ATTRIBUTE
+
 target triple = "dxil-pc-shadermodel6.6-compute"
 
 ; CHECK: !dx.shaderModel = !{![[SM:[0-9]+]]}
@@ -9,4 +11,6 @@ entry:
   ret void
 }
 
+; Make sure extra attribute like hlsl.numthreads are removed.
+; REMOVE_EXTRA_ATTRIBUTE:attributes #0 = { noinline nounwind } 
 attributes #0 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }
diff --git a/llvm/test/CodeGen/DirectX/dxil_ver.ll b/llvm/test/CodeGen/DirectX/dxil_ver.ll
index e9923a3abce02..c893a912f92f1 100644
--- a/llvm/test/CodeGen/DirectX/dxil_ver.ll
+++ b/llvm/test/CodeGen/DirectX/dxil_ver.ll
@@ -1,4 +1,5 @@
 ; RUN: opt -S -dxil-metadata-emit < %s | FileCheck %s
+; RUN: opt -S -dxil-prepare  %s | FileCheck %s  --check-prefix=REMOVE_EXTRA_MODULE_FLAG
 target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
 target triple = "dxil-pc-shadermodel6.3-library"
 
@@ -11,6 +12,10 @@ target triple = "dxil-pc-shadermodel6.3-library"
 ; Make sure wchar_size still exist.
 ; CHECK-DAG:!{i32 1, !"wchar_size", i32 4}
 
+; Make sure no !llvm.module.flags left.
+; REMOVE_EXTRA_MODULE_FLAG: target triple = "dxil-pc-shadermodel6.3-library"
+; REMOVE_EXTRA_MODULE_FLAG-NOT: !llvm.module.flags
+
 !llvm.module.flags = !{!0}
 !dx.valver = !{!1}
 !llvm.ident = !{!2}
diff --git a/llvm/test/tools/dxil-dis/attribute-filter.ll b/llvm/test/tools/dxil-dis/attribute-filter.ll
index 432a5a1b71018..03c7c36c31258 100644
--- a/llvm/test/tools/dxil-dis/attribute-filter.ll
+++ b/llvm/test/tools/dxil-dis/attribute-filter.ll
@@ -19,8 +19,8 @@ define float @fma2(float %0, float %1, float %2) #1 {
   ret float %5
 }
 
-; CHECK: attributes #0 = { nounwind readnone "disable-tail-calls"="false" }
+; CHECK: attributes #0 = { nounwind readnone }
 attributes #0 = { norecurse nounwind readnone willreturn "disable-tail-calls"="false" }
 
-; CHECK: attributes #1 = { readnone "disable-tail-calls"="false" }
+; CHECK: attributes #1 = { readnone }
 attributes #1 = { norecurse memory(none) willreturn "disable-tail-calls"="false" }

>From 463b280c75198cab5c127bc47f45df5f41049a44 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Fri, 3 May 2024 16:36:47 -0400
Subject: [PATCH 2/6] Remove the module flag part.

---
 llvm/lib/Target/DirectX/DXILPrepare.cpp | 31 -------------------------
 llvm/test/CodeGen/DirectX/dxil_ver.ll   |  5 ----
 2 files changed, 36 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index e3a7ddee083b0..a6680a4a40089 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -108,35 +108,6 @@ static void removeStringFunctionAttributes(Function &F) {
   F.removeRetAttrs(DeadAttrs);
 }
 
-static void cleanModuleFlags(Module &M) {
-  NamedMDNode *MDFlags = M.getModuleFlagsMetadata();
-  if (!MDFlags)
-    return;
-
-  StringSet<> LiveKeys = {"Dwarf Version", "Debug Info Version"};
-
-  SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries;
-  M.getModuleFlagsMetadata(FlagEntries);
-
-  bool HasDeadKey = false;
-  for (auto &Flag : FlagEntries) {
-    if (!LiveKeys.count(Flag.Key->getString())) {
-      HasDeadKey = true;
-      break;
-    }
-  }
-  if (!HasDeadKey)
-    return;
-
-  MDFlags->eraseFromParent();
-
-  for (auto &Flag : FlagEntries) {
-    if (!LiveKeys.count(Flag.Key->getString()))
-      continue;
-    M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);
-  }
-}
-
 class DXILPrepareModule : public ModulePass {
 
   static Value *maybeGenerateBitcast(IRBuilder<> &Builder,
@@ -226,8 +197,6 @@ class DXILPrepareModule : public ModulePass {
         }
       }
     }
-    // Remove flags not in llvm3.7.
-    cleanModuleFlags(M);
     return true;
   }
 
diff --git a/llvm/test/CodeGen/DirectX/dxil_ver.ll b/llvm/test/CodeGen/DirectX/dxil_ver.ll
index c893a912f92f1..e9923a3abce02 100644
--- a/llvm/test/CodeGen/DirectX/dxil_ver.ll
+++ b/llvm/test/CodeGen/DirectX/dxil_ver.ll
@@ -1,5 +1,4 @@
 ; RUN: opt -S -dxil-metadata-emit < %s | FileCheck %s
-; RUN: opt -S -dxil-prepare  %s | FileCheck %s  --check-prefix=REMOVE_EXTRA_MODULE_FLAG
 target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
 target triple = "dxil-pc-shadermodel6.3-library"
 
@@ -12,10 +11,6 @@ target triple = "dxil-pc-shadermodel6.3-library"
 ; Make sure wchar_size still exist.
 ; CHECK-DAG:!{i32 1, !"wchar_size", i32 4}
 
-; Make sure no !llvm.module.flags left.
-; REMOVE_EXTRA_MODULE_FLAG: target triple = "dxil-pc-shadermodel6.3-library"
-; REMOVE_EXTRA_MODULE_FLAG-NOT: !llvm.module.flags
-
 !llvm.module.flags = !{!0}
 !dx.valver = !{!1}
 !llvm.ident = !{!2}

>From fb3e1e7a1b6d7bb57d626556c95fa6e745c37b94 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Fri, 3 May 2024 18:18:38 -0400
Subject: [PATCH 3/6] Test that supported attributes are not removed.

---
 llvm/lib/Target/DirectX/DXILPrepare.cpp      | 6 +++---
 llvm/test/tools/dxil-dis/attribute-filter.ll | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index a6680a4a40089..9b5e9ebad0914 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -84,7 +84,7 @@ constexpr bool isValidForDXIL(Attribute::AttrKind Attr) {
 }
 
 static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
-                                   StringSet<> LiveKeys) {
+                                   const StringSet<> &LiveKeys) {
   for (auto &Attr : AS) {
     if (!Attr.isStringAttribute())
       continue;
@@ -97,8 +97,8 @@ static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
 
 static void removeStringFunctionAttributes(Function &F) {
   AttributeList Attrs = F.getAttributes();
-  StringSet<> LiveKeys = {"waveops-include-helper-lanes"
-                          "fp32-denorm-mode"};
+  const StringSet<> LiveKeys = {"waveops-include-helper-lanes",
+                                "fp32-denorm-mode"};
   // Collect DeadKeys in FnAttrs.
   AttributeMask DeadAttrs;
   collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys);
diff --git a/llvm/test/tools/dxil-dis/attribute-filter.ll b/llvm/test/tools/dxil-dis/attribute-filter.ll
index 03c7c36c31258..27590e10d79b5 100644
--- a/llvm/test/tools/dxil-dis/attribute-filter.ll
+++ b/llvm/test/tools/dxil-dis/attribute-filter.ll
@@ -19,8 +19,8 @@ define float @fma2(float %0, float %1, float %2) #1 {
   ret float %5
 }
 
-; CHECK: attributes #0 = { nounwind readnone }
-attributes #0 = { norecurse nounwind readnone willreturn "disable-tail-calls"="false" }
+; CHECK: attributes #0 = { nounwind readnone "fp32-denorm-mode"="any" "waveops-include-helper-lanes" }
+attributes #0 = { norecurse nounwind readnone willreturn "disable-tail-calls"="false" "waveops-include-helper-lanes" "fp32-denorm-mode"="any" }
 
-; CHECK: attributes #1 = { readnone }
-attributes #1 = { norecurse memory(none) willreturn "disable-tail-calls"="false" }
+; CHECK: attributes #1 = { readnone "fp32-denorm-mode"="ftz" "waveops-include-helper-lanes" }
+attributes #1 = { norecurse memory(none) willreturn "disable-tail-calls"="false" "waveops-include-helper-lanes" "fp32-denorm-mode"="ftz" }

>From 77d48a658ceb2a2c2ff25d0a76f3a19eb83a3097 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Mon, 6 May 2024 15:21:02 -0400
Subject: [PATCH 4/6] Skip removeStringFunctionAttributes when validation
 version is 0.0.

---
 llvm/lib/Target/DirectX/DXILMetadata.cpp         |  9 +++++++++
 llvm/lib/Target/DirectX/DXILMetadata.h           |  1 +
 llvm/lib/Target/DirectX/DXILPrepare.cpp          | 14 +++++++++++++-
 .../Metadata/shaderModel-cs-val-ver-0.0.ll       | 16 ++++++++++++++++
 4 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll

diff --git a/llvm/lib/Target/DirectX/DXILMetadata.cpp b/llvm/lib/Target/DirectX/DXILMetadata.cpp
index 2d94490a7f24c..03758dc76e7eb 100644
--- a/llvm/lib/Target/DirectX/DXILMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILMetadata.cpp
@@ -40,6 +40,15 @@ void ValidatorVersionMD::update(VersionTuple ValidatorVer) {
 
 bool ValidatorVersionMD::isEmpty() { return Entry->getNumOperands() == 0; }
 
+VersionTuple ValidatorVersionMD::getAsVersionTuple() {
+  if (isEmpty())
+    return VersionTuple(1, 0);
+  auto *ValVerMD = cast<MDNode>(Entry->getOperand(0));
+  auto *MajorMD = mdconst::extract<ConstantInt>(ValVerMD->getOperand(0));
+  auto *MinorMD = mdconst::extract<ConstantInt>(ValVerMD->getOperand(1));
+  return VersionTuple(MajorMD->getZExtValue(), MinorMD->getZExtValue());
+}
+
 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
   switch (Env) {
   case Triple::Pixel:
diff --git a/llvm/lib/Target/DirectX/DXILMetadata.h b/llvm/lib/Target/DirectX/DXILMetadata.h
index 2f5d7d9fe7683..cd9f4c83fbd0f 100644
--- a/llvm/lib/Target/DirectX/DXILMetadata.h
+++ b/llvm/lib/Target/DirectX/DXILMetadata.h
@@ -30,6 +30,7 @@ class ValidatorVersionMD {
   void update(VersionTuple ValidatorVer);
 
   bool isEmpty();
+  VersionTuple getAsVersionTuple();
 };
 
 void createShaderModelMD(Module &M);
diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 9b5e9ebad0914..e1352fa2cf13d 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -11,6 +11,7 @@
 /// Language (DXIL).
 //===----------------------------------------------------------------------===//
 
+#include "DXILMetadata.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
@@ -26,6 +27,7 @@
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Compiler.h"
+#include "llvm/Support/VersionTuple.h"
 
 #define DEBUG_TYPE "dxil-prepare"
 
@@ -138,10 +140,20 @@ class DXILPrepareModule : public ModulePass {
       if (!isValidForDXIL(I))
         AttrMask.addAttribute(I);
     }
+
+    
+    dxil::ValidatorVersionMD ValVerMD(M);
+    VersionTuple ValVer = ValVerMD.getAsVersionTuple();
+    bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0;
+
     for (auto &F : M.functions()) {
       F.removeFnAttrs(AttrMask);
       F.removeRetAttrs(AttrMask);
-      removeStringFunctionAttributes(F);
+      // Only remove string attributes if we are not skipping validation.
+      // This will reserve the experimental attributes when validation version
+      // is 0.0 for experiment mode.
+      if (!SkipValidation)
+        removeStringFunctionAttributes(F);
       for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx)
         F.removeParamAttrs(Idx, AttrMask);
 
diff --git a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll
new file mode 100644
index 0000000000000..89590845c68f6
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll
@@ -0,0 +1,16 @@
+; RUN: opt -S -dxil-prepare  %s | FileCheck %s 
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+define void @entry() #0 {
+entry:
+  ret void
+}
+
+; Make sure extra attribute like hlsl.numthreads are left when validation version is 0.0.
+; CHECK:attributes #0 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" } 
+attributes #0 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }
+
+!dx.valver = !{!0}
+
+!0 = !{i32 0, i32 0}

>From cdb8ca01e22938466442a00d61d0b5a6c7485cd3 Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Mon, 6 May 2024 16:04:04 -0400
Subject: [PATCH 5/6] clang-format fix.

---
 llvm/lib/Target/DirectX/DXILPrepare.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index e1352fa2cf13d..de9269b85ff2c 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -141,7 +141,6 @@ class DXILPrepareModule : public ModulePass {
         AttrMask.addAttribute(I);
     }
 
-    
     dxil::ValidatorVersionMD ValVerMD(M);
     VersionTuple ValVer = ValVerMD.getAsVersionTuple();
     bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0;

>From 33a3fc1d4c12e693fca8b3ee754f8fa50bd9e83e Mon Sep 17 00:00:00 2001
From: Xiang Li <python3kgae at outlook.com>
Date: Tue, 7 May 2024 10:29:07 -0400
Subject: [PATCH 6/6] Only keep experimental attributes when the validator version
 is 0.0.

---
 llvm/lib/Target/DirectX/DXILPrepare.cpp         | 17 +++++++++++------
 .../Metadata/shaderModel-cs-val-ver-0.0.ll      |  6 +++---
 .../CodeGen/DirectX/Metadata/shaderModel-cs.ll  |  3 ++-
 3 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index de9269b85ff2c..24be644d9fc0e 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -86,25 +86,31 @@ constexpr bool isValidForDXIL(Attribute::AttrKind Attr) {
 }
 
 static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
-                                   const StringSet<> &LiveKeys) {
+                                   const StringSet<> &LiveKeys,
+                                   bool AllowExperimental) {
   for (auto &Attr : AS) {
     if (!Attr.isStringAttribute())
       continue;
     StringRef Key = Attr.getKindAsString();
     if (LiveKeys.contains(Key))
       continue;
+    if (AllowExperimental && Key.starts_with("exp-"))
+      continue;
     DeadAttrs.addAttribute(Key);
   }
 }
 
-static void removeStringFunctionAttributes(Function &F) {
+static void removeStringFunctionAttributes(Function &F,
+                                           bool AllowExperimental) {
   AttributeList Attrs = F.getAttributes();
   const StringSet<> LiveKeys = {"waveops-include-helper-lanes",
                                 "fp32-denorm-mode"};
   // Collect DeadKeys in FnAttrs.
   AttributeMask DeadAttrs;
-  collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys);
-  collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys);
+  collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys,
+                         AllowExperimental);
+  collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys,
+                         AllowExperimental);
 
   F.removeFnAttrs(DeadAttrs);
   F.removeRetAttrs(DeadAttrs);
@@ -151,8 +157,7 @@ class DXILPrepareModule : public ModulePass {
       // Only remove string attributes if we are not skipping validation.
       // This will reserve the experimental attributes when validation version
       // is 0.0 for experiment mode.
-      if (!SkipValidation)
-        removeStringFunctionAttributes(F);
+      removeStringFunctionAttributes(F, SkipValidation);
       for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx)
         F.removeParamAttrs(Idx, AttrMask);
 
diff --git a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll
index 89590845c68f6..a85dc43ac2f6c 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs-val-ver-0.0.ll
@@ -7,9 +7,9 @@ entry:
   ret void
 }
 
-; Make sure extra attribute like hlsl.numthreads are left when validation version is 0.0.
-; CHECK:attributes #0 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" } 
-attributes #0 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }
+; Make sure experimental attribute is left when validation version is 0.0.
+; CHECK:attributes #0 = { noinline nounwind "exp-shader"="cs" } 
+attributes #0 = { noinline nounwind "exp-shader"="cs" "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }
 
 !dx.valver = !{!0}
 
diff --git a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll
index f617ad8d299ef..343f190d994f0 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/shaderModel-cs.ll
@@ -12,5 +12,6 @@ entry:
 }
 
 ; Make sure extra attribute like hlsl.numthreads are removed.
+; And experimental attribute is removed when validator version is not 0.0.
 ; REMOVE_EXTRA_ATTRIBUTE:attributes #0 = { noinline nounwind } 
-attributes #0 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }
+attributes #0 = { noinline nounwind "exp-shader"="cs" "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }



More information about the llvm-commits mailing list