[clang] [llvm] [DirectX] emit `dx.precise` metadata when using `-Gis` flag (PR #192526)

via cfe-commits cfe-commits at lists.llvm.org
Fri Apr 17 11:17:32 PDT 2026


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/192526

>From 7e74f5fa910771a7e90dbdb7816d97de2e34302d Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 11:55:02 -0700
Subject: [PATCH 1/5] add precise support to intrinsics

---
 clang/lib/Driver/ToolChains/HLSL.cpp      |  3 ++
 llvm/lib/Target/DirectX/DXIL.td           | 33 ++++++++++++++++++++
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 37 +++++++++++++++++++++++
 llvm/utils/TableGen/DXILEmitter.cpp       |  4 ++-
 4 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp
index 834b8acc78734..163a66a9703ee 100644
--- a/clang/lib/Driver/ToolChains/HLSL.cpp
+++ b/clang/lib/Driver/ToolChains/HLSL.cpp
@@ -512,6 +512,9 @@ HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
       // Translate -Gis into -ffp_model_EQ=strict
       DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_ffp_model_EQ),
                           "strict");
+
+      DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_mllvm),
+                          StringRef("-enable-precise"));
       A->claim();
       continue;
     }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 0a1e0114aa3bb..8633d1b9f43d1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -405,6 +405,8 @@ class DXILOp<int opcode, DXILOpClass opclass> {
 
   // Versioned attributes of operation
   list<Attributes> attributes = [];
+
+  bit precise = 0;
 }
 
 // Concrete definitions of DXIL Operations
@@ -419,6 +421,7 @@ def Abs : DXILOp<6, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Saturate : DXILOp<7, unary> {
@@ -430,6 +433,7 @@ def Saturate : DXILOp<7, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def IsNaN : DXILOp<8, isSpecialFloat> {
@@ -478,6 +482,7 @@ def Cos : DXILOp<12, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Sin : DXILOp<13, unary> {
@@ -488,6 +493,7 @@ def Sin : DXILOp<13, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Tan : DXILOp<14, unary> {
@@ -498,6 +504,7 @@ def Tan : DXILOp<14, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def ACos : DXILOp<15, unary> {
@@ -508,6 +515,7 @@ def ACos : DXILOp<15, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def ASin : DXILOp<16, unary> {
@@ -518,6 +526,7 @@ def ASin : DXILOp<16, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def ATan : DXILOp<17, unary> {
@@ -528,6 +537,7 @@ def ATan : DXILOp<17, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def HCos : DXILOp<18, unary> {
@@ -538,6 +548,7 @@ def HCos : DXILOp<18, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def HSin : DXILOp<19, unary> {
@@ -548,6 +559,7 @@ def HSin : DXILOp<19, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def HTan : DXILOp<20, unary> {
@@ -558,6 +570,7 @@ def HTan : DXILOp<20, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Exp2 : DXILOp<21, unary> {
@@ -569,6 +582,7 @@ def Exp2 : DXILOp<21, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Frac : DXILOp<22, unary> {
@@ -580,6 +594,7 @@ def Frac : DXILOp<22, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Log2 : DXILOp<23, unary> {
@@ -590,6 +605,7 @@ def Log2 : DXILOp<23, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Sqrt : DXILOp<24, unary> {
@@ -601,6 +617,7 @@ def Sqrt : DXILOp<24, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def RSqrt : DXILOp<25, unary> {
@@ -612,6 +629,7 @@ def RSqrt : DXILOp<25, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Round : DXILOp<26, unary> {
@@ -623,6 +641,7 @@ def Round : DXILOp<26, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Floor : DXILOp<27, unary> {
@@ -634,6 +653,7 @@ def Floor : DXILOp<27, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Ceil : DXILOp<28, unary> {
@@ -645,6 +665,7 @@ def Ceil : DXILOp<28, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Trunc : DXILOp<29, unary> {
@@ -655,6 +676,7 @@ def Trunc : DXILOp<29, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Rbits : DXILOp<30, unary> {
@@ -717,6 +739,7 @@ def FMax : DXILOp<35, binary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def FMin : DXILOp<36, binary> {
@@ -727,6 +750,7 @@ def FMin : DXILOp<36, binary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def SMax : DXILOp<37, binary> {
@@ -788,6 +812,7 @@ def FMad : DXILOp<46, tertiary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Fma : DXILOp<47, tertiary> {
@@ -798,6 +823,7 @@ def Fma : DXILOp<47, tertiary> {
   let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def IMad : DXILOp<48, tertiary> {
@@ -831,6 +857,7 @@ def Dot2 : DXILOp<54, dot2> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Dot3 : DXILOp<55, dot3> {
@@ -842,6 +869,7 @@ def Dot3 : DXILOp<55, dot3> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def Dot4 : DXILOp<56, dot4> {
@@ -853,6 +881,7 @@ def Dot4 : DXILOp<56, dot4> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+  let precise = 1;
 }
 
 def CreateHandle : DXILOp<57, createHandle> {
@@ -1107,6 +1136,7 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
   let overloads = [Overloads<
       DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let precise = 1;
 }
 
 def WaveActiveOp : DXILOp<119, waveActiveOp> {
@@ -1261,6 +1291,7 @@ def LegacyF32ToF16 : DXILOp<130, legacyF32ToF16> {
   let arguments = [FloatTy];
   let result = Int32Ty;
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let precise = 1;
 }
 
 def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
@@ -1270,6 +1301,7 @@ def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
   let arguments = [Int32Ty];
   let result = FloatTy;
   let stages = [Stages<DXIL1_0, [all_stages]>];
+  let precise = 1;
 }
 
 def WaveAllBitCount : DXILOp<135, waveAllOp> {
@@ -1325,6 +1357,7 @@ def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
   let overloads = [Overloads<DXIL1_4, [FloatTy]>];
   let stages = [Stages<DXIL1_4, [all_stages]>];
   let attributes = [Attributes<DXIL1_4, [ReadNone]>];
+  let precise = 1;
 }
 
 def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 1f41d2457e5bc..e0a9f78a18a26 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -22,6 +22,12 @@ using namespace llvm::dxil;
 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
 
 namespace {
+
+static cl::opt<bool>
+    OpPreciseEnabled("enable-precise",
+                     cl::desc("Enables emission of dx.precise"),
+                     cl::init(false));
+
 enum OverloadKind : uint16_t {
   UNDEFINED = 0,
   VOID = 1,
@@ -155,6 +161,7 @@ struct OpCodeProperty {
   llvm::SmallVector<OpStage> Stages;
   int OverloadParamIndex; // parameter index which control the overload.
                           // When < 0, should be only 1 overload type.
+  bool Precise;
 };
 
 // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and
@@ -484,6 +491,35 @@ static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
   return;
 }
 
+static bool isOverloadTyFloat(uint16_t ValidTyMask) {
+  if (ValidTyMask == OverloadKind::UNDEFINED)
+    return false;
+  return (ValidTyMask &
+          ((uint16_t)OverloadKind::HALF | (uint16_t)OverloadKind::FLOAT |
+           (uint16_t)OverloadKind::DOUBLE)) != 0;
+}
+
+static void setDXILMetadata(CallInst *CI, const OpCodeProperty *Prop) {
+  if (OpPreciseEnabled) {
+    bool AllOverloadAreFloat = false;
+    for (const OpOverload &Overload : Prop->Overloads)
+      AllOverloadAreFloat =
+          AllOverloadAreFloat || isOverloadTyFloat(Overload.ValidTys);
+
+    if (AllOverloadAreFloat && Prop->Precise) {
+      const StringRef Key = "dx.precise";
+      Module *M = CI->getModule();
+
+      LLVMContext &Ctx = M->getContext();
+      MDNode *One =
+          llvm::MDNode::get(Ctx, ConstantAsMetadata::get(ConstantInt::get(
+                                     llvm::Type::getInt32Ty(Ctx), 1)));
+
+      CI->setMetadata(Key, One);
+    }
+  }
+}
+
 namespace llvm {
 namespace dxil {
 
@@ -583,6 +619,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
   // We then need to attach available function attributes
   setDXILAttributes(CI, OpCode, DXILVersion);
 
+  setDXILMetadata(CI, Prop);
   return CI;
 }
 
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 94719598dfd58..26f7ed7f7f502 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -42,6 +42,7 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
+  bool CanUsePrecise; // Can this operation be maker with dx.precise
   // Vector of operand type records - return type is at index 0
   SmallVector<const Record *> OpTypes;
   SmallVector<const Record *> OverloadRecs;
@@ -107,6 +108,7 @@ static StringRef GetIntrinsicName(const RecordVal *RV) {
 DXILOperationDesc::DXILOperationDesc(const Record *R) {
   OpName = R->getNameInitAsString();
   OpCode = R->getValueAsInt("OpCode");
+  CanUsePrecise = R->getValueAsBit("precise");
 
   Doc = R->getValueAsString("Doc");
   SmallVector<const Record *> ParamTypeRecs;
@@ -507,7 +509,7 @@ static void emitDXILOperationTable(ArrayRef<DXILOperationDesc> Ops,
        << OpClassStrings.get(Op.OpClass.data()) << ", "
        << getOverloadMaskString(Op.OverloadRecs) << ", "
        << getStageMaskString(Op.StageRecs) << ", " << Op.OverloadParamIndex
-       << " }";
+       << ", " << (Op.CanUsePrecise ? "true" : "false") << " }";
     Prefix = ",\n";
   }
   OS << "  };\n";

>From 49f1c2064f6565f7795e0daf1145d8c1e5a82c74 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 13:13:26 -0700
Subject: [PATCH 2/5] add test

---
 .../test/CodeGen/DirectX/Metadata/dx_precise.ll | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)
 create mode 100644 llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll

diff --git a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
new file mode 100644
index 0000000000000..956cb00dbc1c5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
@@ -0,0 +1,17 @@
+; RUN: llc %s -enable-precise --filetype=asm -o - 2>&1 | FileCheck %s --check-prefixes=ENABLED
+; RUN: llc %s --filetype=asm -o - 2>&1 | FileCheck %s --check-prefixes=DISABLED
+
+
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+; ENABLED: call float @dx.op.unary.f32(i32 12, float %conv)
+; ENABLED-SAME: !dx.precise ![[SM:[0-9]+]]
+; ENABLED: ![[SM]] = !{i32 1}
+
+; DISABLED-NOT: !dx.precise ![[SM:[0-9]+]]
+define void @main(float %conv) {
+entry:
+  %1 = call float @llvm.cos.f32(float %conv)
+  ret void
+}

>From 41a49118731f6b178a73dae42afe359dadb3324d Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 13:15:09 -0700
Subject: [PATCH 3/5] add comment

---
 llvm/lib/Target/DirectX/DXIL.td | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 8633d1b9f43d1..2981fa8721f65 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -406,6 +406,7 @@ class DXILOp<int opcode, DXILOpClass opclass> {
   // Versioned attributes of operation
   list<Attributes> attributes = [];
 
+  // Emmit dx.precise to prevent optimizations
   bit precise = 0;
 }
 

>From 5502c2d358291b4a5478668529cdbfd2a94d4c26 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 15:57:43 -0700
Subject: [PATCH 4/5] address damyan comments

---
 llvm/lib/Target/DirectX/DXIL.td               |  5 +-
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     | 10 +--
 .../CodeGen/DirectX/Metadata/dx_precise.ll    | 69 ++++++++++++++++++-
 llvm/utils/TableGen/DXILEmitter.cpp           |  2 +-
 4 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 2981fa8721f65..47dca39679d3f 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -406,7 +406,7 @@ class DXILOp<int opcode, DXILOpClass opclass> {
   // Versioned attributes of operation
   list<Attributes> attributes = [];
 
-  // Emmit dx.precise to prevent optimizations
+  // Does the operation support precise calculation?
   bit precise = 0;
 }
 
@@ -422,7 +422,6 @@ def Abs : DXILOp<6, unary> {
   let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
-  let precise = 1;
 }
 
 def Saturate : DXILOp<7, unary> {
@@ -1302,7 +1301,6 @@ def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
   let arguments = [Int32Ty];
   let result = FloatTy;
   let stages = [Stages<DXIL1_0, [all_stages]>];
-  let precise = 1;
 }
 
 def WaveAllBitCount : DXILOp<135, waveAllOp> {
@@ -1358,7 +1356,6 @@ def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
   let overloads = [Overloads<DXIL1_4, [FloatTy]>];
   let stages = [Stages<DXIL1_4, [all_stages]>];
   let attributes = [Attributes<DXIL1_4, [ReadNone]>];
-  let precise = 1;
 }
 
 def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index e0a9f78a18a26..c5ffed559b16e 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -500,13 +500,13 @@ static bool isOverloadTyFloat(uint16_t ValidTyMask) {
 }
 
 static void setDXILMetadata(CallInst *CI, const OpCodeProperty *Prop) {
-  if (OpPreciseEnabled) {
-    bool AllOverloadAreFloat = false;
+  if (OpPreciseEnabled && Prop->Precise) {
+    bool HasFloatOverload = false;
     for (const OpOverload &Overload : Prop->Overloads)
-      AllOverloadAreFloat =
-          AllOverloadAreFloat || isOverloadTyFloat(Overload.ValidTys);
+      HasFloatOverload =
+          HasFloatOverload || isOverloadTyFloat(Overload.ValidTys);
 
-    if (AllOverloadAreFloat && Prop->Precise) {
+    if (HasFloatOverload) {
       const StringRef Key = "dx.precise";
       Module *M = CI->getModule();
 
diff --git a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
index 956cb00dbc1c5..e2dc635b7b1db 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
@@ -5,13 +5,76 @@
 
 target triple = "dxil-pc-shadermodel6.6-compute"
 
-; ENABLED: call float @dx.op.unary.f32(i32 12, float %conv)
+; ENABLED: call float @dx.op.unary.f32(i32 7,
 ; ENABLED-SAME: !dx.precise ![[SM:[0-9]+]]
+; ENABLED-COUNT-26: !dx.precise ![[SM]]
+; ENABLED-NOT: !dx.precise ![[SM]]
 ; ENABLED: ![[SM]] = !{i32 1}
 
 ; DISABLED-NOT: !dx.precise ![[SM:[0-9]+]]
-define void @main(float %conv) {
+define void @unary(float %p) {
 entry:
-  %1 = call float @llvm.cos.f32(float %conv)
+  %1 = call float @llvm.dx.saturate.f32(float %p)
+  %2 = call float @llvm.cos.f32(float %p)
+  %3 = call float @llvm.sin.f32(float %p)
+  %4 = call float @llvm.tan.f32(float %p)
+  %5 = call float @llvm.acos.f32(float %p)
+  %6 = call float @llvm.asin.f32(float %p)
+  %7 = call float @llvm.atan.f32(float %p)
+  %8 = call float @llvm.cosh.f32(float %p)
+  %9 = call float @llvm.sinh.f32(float %p)
+  %10 = call float @llvm.tanh.f32(float %p)
+  %11 = call float @llvm.exp2.f32(float %p)
+  %12 = call float @llvm.dx.frac.f32(float %p)
+  %13 = call float @llvm.log2.f32(float %p)
+  %14 = call float @llvm.log2.f32(float %p)
+  %15 = call float @llvm.sqrt.f32(float %p)
+  %16 = call float @llvm.roundeven.f32(float %p)
+  %17 = call float @llvm.floor.f32(float %p)
+  %18 = call float @llvm.ceil.f32(float %p)
+  %19 = call float @llvm.trunc.f32(float %p)
+  ret void
+}
+
+define void @binary(float %p1, float %p2) {
+entry:
+  %20 = call float @llvm.maxnum.f32(float %p1, float %p2)
+  %21 = call float @llvm.minnum.f32(float %p1, float %p2)
+  ret void
+}
+
+define void @tertiary(float %p1, float %p2, float %p3) {
+entry:
+  %22 = call float @llvm.fmuladd.f32(float %p1, float %p2, float %p3)
+  ret void
+}
+
+define void @fma(double %p1, double %p2, double %p3) {
+entry:
+  %23 = call double @llvm.fma.f64(double %p1, double %p2, double %p3)
+  ret void
+}
+
+define void @dot2(<2 x float> %a, <2 x float> %b) {
+entry:
+  %24 = call float @llvm.dx.fdot.v2f32(<2 x float> %a, <2 x float> %b)
+  ret void
+}
+
+define void @dot3(<3 x float> %a, <3 x float> %b) {
+entry:
+  %25 = call float @llvm.dx.fdot.v3f32(<3 x float> %a, <3 x float> %b)
+  ret void
+}
+
+define void @dot4(<4 x float> %a, <4 x float> %b) {
+entry:
+  %26 = call float @llvm.dx.fdot.v4f32(<4 x float> %a, <4 x float> %b)
+  ret void
+}
+
+define void @wave_rla(float %expr, i32 %idx) {
+entry:
+  %27 = call float @llvm.dx.wave.readlane(float %expr, i32 %idx)
   ret void
 }
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 26f7ed7f7f502..e32c1f4485469 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -42,7 +42,7 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  bool CanUsePrecise; // Can this operation be maker with dx.precise
+  bool CanUsePrecise; // Can this operation be marked with dx.precise
   // Vector of operand type records - return type is at index 0
   SmallVector<const Record *> OpTypes;
   SmallVector<const Record *> OverloadRecs;

>From e07febc0ffb72cfc9304503956d4966a5d69eb62 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Fri, 17 Apr 2026 11:17:15 -0700
Subject: [PATCH 5/5] simplify and improve testing

---
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     | 38 +++----
 .../CodeGen/DirectX/Metadata/dx_precise.ll    | 98 ++++++++++++++++---
 2 files changed, 98 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index c5ffed559b16e..414da9b26b9d1 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -491,32 +491,20 @@ static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
   return;
 }
 
-static bool isOverloadTyFloat(uint16_t ValidTyMask) {
-  if (ValidTyMask == OverloadKind::UNDEFINED)
-    return false;
-  return (ValidTyMask &
-          ((uint16_t)OverloadKind::HALF | (uint16_t)OverloadKind::FLOAT |
-           (uint16_t)OverloadKind::DOUBLE)) != 0;
-}
-
 static void setDXILMetadata(CallInst *CI, const OpCodeProperty *Prop) {
-  if (OpPreciseEnabled && Prop->Precise) {
-    bool HasFloatOverload = false;
-    for (const OpOverload &Overload : Prop->Overloads)
-      HasFloatOverload =
-          HasFloatOverload || isOverloadTyFloat(Overload.ValidTys);
-
-    if (HasFloatOverload) {
-      const StringRef Key = "dx.precise";
-      Module *M = CI->getModule();
-
-      LLVMContext &Ctx = M->getContext();
-      MDNode *One =
-          llvm::MDNode::get(Ctx, ConstantAsMetadata::get(ConstantInt::get(
-                                     llvm::Type::getInt32Ty(Ctx), 1)));
-
-      CI->setMetadata(Key, One);
-    }
+  // Attach !dx.precise !{i32 1} only when -enable-precise is set, the op is
+  // marked precise in DXIL.td, and the call returns a floating-point type.
+  if (OpPreciseEnabled && Prop->Precise &&
+      CI->getFunctionType()->getReturnType()->isFloatingPointTy()) {
+    const StringRef Key = "dx.precise";
+    Module *M = CI->getModule();
+
+    LLVMContext &Ctx = M->getContext();
+    MDNode *One =
+        llvm::MDNode::get(Ctx, ConstantAsMetadata::get(ConstantInt::get(
+                                   llvm::Type::getInt32Ty(Ctx), 1)));
+
+    CI->setMetadata(Key, One);
   }
 }
 
diff --git a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
index e2dc635b7b1db..2af993e4380f5 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
@@ -7,12 +7,12 @@ target triple = "dxil-pc-shadermodel6.6-compute"
 
 ; ENABLED: call float @dx.op.unary.f32(i32 7,
 ; ENABLED-SAME: !dx.precise ![[SM:[0-9]+]]
-; ENABLED-COUNT-26: !dx.precise ![[SM]]
+; ENABLED-COUNT-52: !dx.precise ![[SM]]
 ; ENABLED-NOT: !dx.precise ![[SM]]
 ; ENABLED: ![[SM]] = !{i32 1}
 
 ; DISABLED-NOT: !dx.precise ![[SM:[0-9]+]]
-define void @unary(float %p) {
+define void @unary_f32(float %p) {
 entry:
   %1 = call float @llvm.dx.saturate.f32(float %p)
   %2 = call float @llvm.cos.f32(float %p)
@@ -27,22 +27,64 @@ entry:
   %11 = call float @llvm.exp2.f32(float %p)
   %12 = call float @llvm.dx.frac.f32(float %p)
   %13 = call float @llvm.log2.f32(float %p)
-  %14 = call float @llvm.log2.f32(float %p)
-  %15 = call float @llvm.sqrt.f32(float %p)
-  %16 = call float @llvm.roundeven.f32(float %p)
-  %17 = call float @llvm.floor.f32(float %p)
-  %18 = call float @llvm.ceil.f32(float %p)
-  %19 = call float @llvm.trunc.f32(float %p)
+  %14 = call float @llvm.sqrt.f32(float %p)
+  %15 = call float @llvm.roundeven.f32(float %p)
+  %16 = call float @llvm.floor.f32(float %p)
+  %17 = call float @llvm.ceil.f32(float %p)
+  %18 = call float @llvm.trunc.f32(float %p)
   ret void
 }
 
-define void @binary(float %p1, float %p2) {
+define void @unary_f16(half %p) {
+entry:
+  %1 = call half @llvm.dx.saturate.f16(half %p)
+  %2 = call half @llvm.cos.f16(half %p)
+  %3 = call half @llvm.sin.f16(half %p)
+  %4 = call half @llvm.tan.f16(half %p)
+  %5 = call half @llvm.acos.f16(half %p)
+  %6 = call half @llvm.asin.f16(half %p)
+  %7 = call half @llvm.atan.f16(half %p)
+  %8 = call half @llvm.cosh.f16(half %p)
+  %9 = call half @llvm.sinh.f16(half %p)
+  %10 = call half @llvm.tanh.f16(half %p)
+  %11 = call half @llvm.exp2.f16(half %p)
+  %12 = call half @llvm.dx.frac.f16(half %p)
+  %13 = call half @llvm.log2.f16(half %p)
+  %14 = call half @llvm.sqrt.f16(half %p)
+  %15 = call half @llvm.roundeven.f16(half %p)
+  %16 = call half @llvm.floor.f16(half %p)
+  %17 = call half @llvm.ceil.f16(half %p)
+  %18 = call half @llvm.trunc.f16(half %p)
+  ret void
+}
+
+define void @unary_f64(double %p) {
+entry:
+  %1 = call double @llvm.dx.saturate.f64(double %p)
+  ret void
+}
+
+define void @binary_f32(float %p1, float %p2) {
 entry:
   %20 = call float @llvm.maxnum.f32(float %p1, float %p2)
   %21 = call float @llvm.minnum.f32(float %p1, float %p2)
   ret void
 }
 
+define void @binary_f16(half %p1, half %p2) {
+entry:
+  %20 = call half @llvm.maxnum.f16(half %p1, half %p2)
+  %21 = call half @llvm.minnum.f16(half %p1, half %p2)
+  ret void
+}
+
+define void @binary_f64(double %p1, double %p2) {
+entry:
+  %20 = call double @llvm.maxnum.f64(double %p1, double %p2)
+  %21 = call double @llvm.minnum.f64(double %p1, double %p2)
+  ret void
+}
+
 define void @tertiary(float %p1, float %p2, float %p3) {
 entry:
   %22 = call float @llvm.fmuladd.f32(float %p1, float %p2, float %p3)
@@ -55,26 +97,56 @@ entry:
   ret void
 }
 
-define void @dot2(<2 x float> %a, <2 x float> %b) {
+define void @dot2_f32(<2 x float> %a, <2 x float> %b) {
 entry:
   %24 = call float @llvm.dx.fdot.v2f32(<2 x float> %a, <2 x float> %b)
   ret void
 }
 
-define void @dot3(<3 x float> %a, <3 x float> %b) {
+define void @dot2_f16(<2 x half> %a, <2 x half> %b) {
+entry:
+  %24 = call half @llvm.dx.fdot.v2f16(<2 x half> %a, <2 x half> %b)
+  ret void
+}
+
+define void @dot3_f32(<3 x float> %a, <3 x float> %b) {
 entry:
   %25 = call float @llvm.dx.fdot.v3f32(<3 x float> %a, <3 x float> %b)
   ret void
 }
 
-define void @dot4(<4 x float> %a, <4 x float> %b) {
+define void @dot3_16(<3 x half> %a, <3 x half> %b) {
+entry:
+  %25 = call half @llvm.dx.fdot.v3f16(<3 x half> %a, <3 x half> %b)
+  ret void
+}
+
+define void @dot4_f32(<4 x float> %a, <4 x float> %b) {
 entry:
   %26 = call float @llvm.dx.fdot.v4f32(<4 x float> %a, <4 x float> %b)
   ret void
 }
 
-define void @wave_rla(float %expr, i32 %idx) {
+define void @dot4_f16(<4 x half> %a, <4 x half> %b) {
+entry:
+  %26 = call half @llvm.dx.fdot.v4f16(<4 x half> %a, <4 x half> %b)
+  ret void
+}
+
+define void @wave_rla_f32(float %expr, i32 %idx) {
 entry:
   %27 = call float @llvm.dx.wave.readlane(float %expr, i32 %idx)
   ret void
 }
+
+define void @wave_rla_f16(half %expr, i32 %idx) {
+entry:
+  %27 = call half @llvm.dx.wave.readlane(half %expr, i32 %idx)
+  ret void
+}
+
+define void @wave_rla_i32(i32 %expr, i32 %idx) {
+entry:
+  %27 = call i32 @llvm.dx.wave.readlane(i32 %expr, i32 %idx)
+  ret void
+}



More information about the cfe-commits mailing list