[llvm] [DirectX] TypedUAVLoadAdditionalFormats shader flag (PR #120477)

Justin Bogner via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 18 12:41:14 PST 2024


https://github.com/bogner created https://github.com/llvm/llvm-project/pull/120477

Set the TypedUAVLoadAdditionalFormats flag if the shader contains a load from a multicomponent UAV.

Fixes #114557

>From 6d07543b2da6304e2d3746792f51020d7b831b16 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Thu, 21 Nov 2024 15:52:02 -0800
Subject: [PATCH] [DirectX] TypedUAVLoadAdditionalFormats shader flag

Set the TypedUAVLoadAdditionalFormats flag if the shader contains a load
from a multicomponent UAV.

Fixes #114557
---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 49 ++++++++++++++++---
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  7 ++-
 .../typed-uav-load-additional-formats.ll      | 44 +++++++++++++++++
 3 files changed, 88 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 2db4c1729c39fc..1e88963345763f 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -14,16 +14,21 @@
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFunctionFlags(ComputedShaderFlags &CSF,
-                                const Instruction &I) {
+static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
+                                DXILResourceTypeMap &DRTM) {
   if (!CSF.Doubles)
     CSF.Doubles = I.getType()->isDoubleTy();
 
@@ -44,9 +49,23 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
       break;
     }
   }
+
+  if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::dx_typedBufferLoad: {
+      dxil::ResourceTypeInfo &RTI =
+          DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
+      if (RTI.isTyped())
+        CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
+    }
+    }
+  }
 }
 
-void ModuleShaderFlags::initialize(const Module &M) {
+void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
+
   // Collect shader flags for each of the functions
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration()) {
@@ -57,7 +76,7 @@ void ModuleShaderFlags::initialize(const Module &M) {
     ComputedShaderFlags CSF;
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFunctionFlags(CSF, I);
+        updateFunctionFlags(CSF, I, DRTM);
     // Insert shader flag mask for function F
     FunctionFlags.push_back({&F, CSF});
     // Update combined shader flags mask
@@ -104,8 +123,11 @@ AnalysisKey ShaderFlagsAnalysis::Key;
 
 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
                                            ModuleAnalysisManager &AM) {
+  DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
+
   ModuleShaderFlags MSFI;
-  MSFI.initialize(M);
+  MSFI.initialize(M, DRTM);
+
   return MSFI;
 }
 
@@ -132,11 +154,22 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
 
 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
-  MSFI.initialize(M);
+  DXILResourceTypeMap &DRTM =
+      getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
+
+  MSFI.initialize(M, DRTM);
   return false;
 }
 
+void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
+}
+
 char ShaderFlagsAnalysisWrapper::ID = 0;
 
-INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
-                "DXIL Shader Flag Analysis", true, true)
+INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
+                      "DXIL Shader Flag Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
+INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
+                    "DXIL Shader Flag Analysis", true, true)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..67ddab39d0f349 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -26,6 +26,7 @@
 namespace llvm {
 class Module;
 class GlobalVariable;
+class DXILResourceTypeMap;
 
 namespace dxil {
 
@@ -84,7 +85,7 @@ struct ComputedShaderFlags {
 };
 
 struct ModuleShaderFlags {
-  void initialize(const Module &);
+  void initialize(const Module &, DXILResourceTypeMap &DRTM);
   const ComputedShaderFlags &getFunctionFlags(const Function *) const;
   const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
 
@@ -135,9 +136,7 @@ class ShaderFlagsAnalysisWrapper : public ModulePass {
 
   bool runOnModule(Module &M) override;
 
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.setPreservesAll();
-  }
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
 };
 
 } // namespace dxil
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll
new file mode 100644
index 00000000000000..b6947393c4533d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/typed-uav-load-additional-formats.ll
@@ -0,0 +1,44 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=CHECK-OBJ
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK-OBJ: - Name: SFI0
+; CHECK-OBJ:   Flags:
+; CHECK-OBJ:     TypedUAVLoadAdditionalFormats: true
+
+; CHECK:      Combined Shader Flags for Module
+; CHECK-NEXT: Shader Flags Value: 0x00002000
+
+; CHECK: Note: shader requires additional functionality:
+; CHECK:       Typed UAV Load Additional Formats
+
+; CHECK: Function multicomponent : 0x00002000
+define <4 x float> @multicomponent() #0 {
+  %res = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+  %val = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %res, i32 0)
+  ret <4 x float> %val
+}
+
+; CHECK: Function onecomponent : 0x00000000
+define float @onecomponent() #0 {
+  %res = call target("dx.TypedBuffer", float, 1, 0, 0)
+      @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+  %val = call float @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", float, 1, 0, 0) %res, i32 0)
+  ret float %val
+}
+
+; CHECK: Function noload : 0x00000000
+define void @noload(<4 x float> %val) #0 {
+  %res = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding(i32 0, i32 0, i32 1, i32 0, i1 false)
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %res, i32 0,
+      <4 x float> %val)
+  ret void
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}



More information about the llvm-commits mailing list