[clang] [llvm] [DirectX] emit `dx.precise` metadata when using `-Gis` flag (PR #192526)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Apr 17 11:17:32 PDT 2026
https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/192526
>From 7e74f5fa910771a7e90dbdb7816d97de2e34302d Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 11:55:02 -0700
Subject: [PATCH 1/5] add precise support to intrinsics
---
clang/lib/Driver/ToolChains/HLSL.cpp | 3 ++
llvm/lib/Target/DirectX/DXIL.td | 33 ++++++++++++++++++++
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 37 +++++++++++++++++++++++
llvm/utils/TableGen/DXILEmitter.cpp | 4 ++-
4 files changed, 76 insertions(+), 1 deletion(-)
diff --git a/clang/lib/Driver/ToolChains/HLSL.cpp b/clang/lib/Driver/ToolChains/HLSL.cpp
index 834b8acc78734..163a66a9703ee 100644
--- a/clang/lib/Driver/ToolChains/HLSL.cpp
+++ b/clang/lib/Driver/ToolChains/HLSL.cpp
@@ -512,6 +512,9 @@ HLSLToolChain::TranslateArgs(const DerivedArgList &Args, StringRef BoundArch,
// Translate -Gis into -ffp_model_EQ=strict
DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_ffp_model_EQ),
"strict");
+ // Also forward -mllvm -enable-precise so the DirectX backend emits dx.precise.
+ DAL->AddSeparateArg(nullptr, Opts.getOption(options::OPT_mllvm),
+ StringRef("-enable-precise"));
A->claim();
continue;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 0a1e0114aa3bb..8633d1b9f43d1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -405,6 +405,8 @@ class DXILOp<int opcode, DXILOpClass opclass> {
// Versioned attributes of operation
list<Attributes> attributes = [];
+
+ bit precise = 0;
}
// Concrete definitions of DXIL Operations
@@ -419,6 +421,7 @@ def Abs : DXILOp<6, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Saturate : DXILOp<7, unary> {
@@ -430,6 +433,7 @@ def Saturate : DXILOp<7, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def IsNaN : DXILOp<8, isSpecialFloat> {
@@ -478,6 +482,7 @@ def Cos : DXILOp<12, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Sin : DXILOp<13, unary> {
@@ -488,6 +493,7 @@ def Sin : DXILOp<13, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Tan : DXILOp<14, unary> {
@@ -498,6 +504,7 @@ def Tan : DXILOp<14, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def ACos : DXILOp<15, unary> {
@@ -508,6 +515,7 @@ def ACos : DXILOp<15, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def ASin : DXILOp<16, unary> {
@@ -518,6 +526,7 @@ def ASin : DXILOp<16, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def ATan : DXILOp<17, unary> {
@@ -528,6 +537,7 @@ def ATan : DXILOp<17, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def HCos : DXILOp<18, unary> {
@@ -538,6 +548,7 @@ def HCos : DXILOp<18, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def HSin : DXILOp<19, unary> {
@@ -548,6 +559,7 @@ def HSin : DXILOp<19, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def HTan : DXILOp<20, unary> {
@@ -558,6 +570,7 @@ def HTan : DXILOp<20, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Exp2 : DXILOp<21, unary> {
@@ -569,6 +582,7 @@ def Exp2 : DXILOp<21, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Frac : DXILOp<22, unary> {
@@ -580,6 +594,7 @@ def Frac : DXILOp<22, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Log2 : DXILOp<23, unary> {
@@ -590,6 +605,7 @@ def Log2 : DXILOp<23, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Sqrt : DXILOp<24, unary> {
@@ -601,6 +617,7 @@ def Sqrt : DXILOp<24, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def RSqrt : DXILOp<25, unary> {
@@ -612,6 +629,7 @@ def RSqrt : DXILOp<25, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Round : DXILOp<26, unary> {
@@ -623,6 +641,7 @@ def Round : DXILOp<26, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Floor : DXILOp<27, unary> {
@@ -634,6 +653,7 @@ def Floor : DXILOp<27, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Ceil : DXILOp<28, unary> {
@@ -645,6 +665,7 @@ def Ceil : DXILOp<28, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Trunc : DXILOp<29, unary> {
@@ -655,6 +676,7 @@ def Trunc : DXILOp<29, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Rbits : DXILOp<30, unary> {
@@ -717,6 +739,7 @@ def FMax : DXILOp<35, binary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def FMin : DXILOp<36, binary> {
@@ -727,6 +750,7 @@ def FMin : DXILOp<36, binary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def SMax : DXILOp<37, binary> {
@@ -788,6 +812,7 @@ def FMad : DXILOp<46, tertiary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Fma : DXILOp<47, tertiary> {
@@ -798,6 +823,7 @@ def Fma : DXILOp<47, tertiary> {
let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def IMad : DXILOp<48, tertiary> {
@@ -831,6 +857,7 @@ def Dot2 : DXILOp<54, dot2> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Dot3 : DXILOp<55, dot3> {
@@ -842,6 +869,7 @@ def Dot3 : DXILOp<55, dot3> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def Dot4 : DXILOp<56, dot4> {
@@ -853,6 +881,7 @@ def Dot4 : DXILOp<56, dot4> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+ let precise = 1;
}
def CreateHandle : DXILOp<57, createHandle> {
@@ -1107,6 +1136,7 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
let overloads = [Overloads<
DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
+ let precise = 1;
}
def WaveActiveOp : DXILOp<119, waveActiveOp> {
@@ -1261,6 +1291,7 @@ def LegacyF32ToF16 : DXILOp<130, legacyF32ToF16> {
let arguments = [FloatTy];
let result = Int32Ty;
let stages = [Stages<DXIL1_0, [all_stages]>];
+ let precise = 1; // NOTE(review): result is Int32Ty; if dx.precise is gated on a floating-point return type this flag has no effect -- confirm intended.
}
def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
@@ -1270,6 +1301,7 @@ def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
let arguments = [Int32Ty];
let result = FloatTy;
let stages = [Stages<DXIL1_0, [all_stages]>];
+ let precise = 1;
}
def WaveAllBitCount : DXILOp<135, waveAllOp> {
@@ -1325,6 +1357,7 @@ def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
let overloads = [Overloads<DXIL1_4, [FloatTy]>];
let stages = [Stages<DXIL1_4, [all_stages]>];
let attributes = [Attributes<DXIL1_4, [ReadNone]>];
+ let precise = 1;
}
def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 1f41d2457e5bc..e0a9f78a18a26 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -22,6 +22,12 @@ using namespace llvm::dxil;
constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
namespace {
+// Set via -mllvm -enable-precise (the clang HLSL driver adds it for -Gis).
+static cl::opt<bool>
+ OpPreciseEnabled("enable-precise",
+ cl::desc("Enables emission of dx.precise"),
+ cl::init(false));
+
enum OverloadKind : uint16_t {
UNDEFINED = 0,
VOID = 1,
@@ -155,6 +161,7 @@ struct OpCodeProperty {
llvm::SmallVector<OpStage> Stages;
int OverloadParamIndex; // parameter index which control the overload.
// When < 0, should be only 1 overload type.
+ bool Precise; // Mirrors the `precise` bit of the DXILOp record in DXIL.td.
};
// Include getOpCodeClassName getOpCodeProperty, getOpCodeName and
@@ -484,6 +491,35 @@ static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
return;
}
+static bool isOverloadTyFloat(uint16_t ValidTyMask) {
+ if (ValidTyMask == OverloadKind::UNDEFINED)
+ return false;
+ return (ValidTyMask &
+ ((uint16_t)OverloadKind::HALF | (uint16_t)OverloadKind::FLOAT |
+ (uint16_t)OverloadKind::DOUBLE)) != 0;
+}
+
+static void setDXILMetadata(CallInst *CI, const OpCodeProperty *Prop) {
+ if (OpPreciseEnabled) {
+ bool AllOverloadAreFloat = false;
+ for (const OpOverload &Overload : Prop->Overloads)
+ AllOverloadAreFloat =
+ AllOverloadAreFloat || isOverloadTyFloat(Overload.ValidTys);
+
+ if (AllOverloadAreFloat && Prop->Precise) {
+ const StringRef Key = "dx.precise";
+ Module *M = CI->getModule();
+
+ LLVMContext &Ctx = M->getContext();
+ MDNode *One =
+ llvm::MDNode::get(Ctx, ConstantAsMetadata::get(ConstantInt::get(
+ llvm::Type::getInt32Ty(Ctx), 1)));
+
+ CI->setMetadata(Key, One);
+ }
+ }
+}
+
namespace llvm {
namespace dxil {
@@ -583,6 +619,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
// We then need to attach available function attributes
setDXILAttributes(CI, OpCode, DXILVersion);
+ setDXILMetadata(CI, Prop);
return CI;
}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 94719598dfd58..26f7ed7f7f502 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -42,6 +42,7 @@ struct DXILOperationDesc {
int OpCode; // ID of DXIL operation
StringRef OpClass; // name of the opcode class
StringRef Doc; // the documentation description of this instruction
+ bool CanUsePrecise; // Can this operation be maker with dx.precise
// Vector of operand type records - return type is at index 0
SmallVector<const Record *> OpTypes;
SmallVector<const Record *> OverloadRecs;
@@ -107,6 +108,7 @@ static StringRef GetIntrinsicName(const RecordVal *RV) {
DXILOperationDesc::DXILOperationDesc(const Record *R) {
OpName = R->getNameInitAsString();
OpCode = R->getValueAsInt("OpCode");
+ CanUsePrecise = R->getValueAsBit("precise"); // `bit precise` field of DXILOp in DXIL.td
Doc = R->getValueAsString("Doc");
SmallVector<const Record *> ParamTypeRecs;
@@ -507,7 +509,7 @@ static void emitDXILOperationTable(ArrayRef<DXILOperationDesc> Ops,
<< OpClassStrings.get(Op.OpClass.data()) << ", "
<< getOverloadMaskString(Op.OverloadRecs) << ", "
<< getStageMaskString(Op.StageRecs) << ", " << Op.OverloadParamIndex
- << " }";
+ << ", " << (Op.CanUsePrecise ? "true" : "false") << " }";
Prefix = ",\n";
}
OS << " };\n";
>From 49f1c2064f6565f7795e0daf1145d8c1e5a82c74 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 13:13:26 -0700
Subject: [PATCH 2/5] add test
---
.../test/CodeGen/DirectX/Metadata/dx_precise.ll | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
create mode 100644 llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
diff --git a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
new file mode 100644
index 0000000000000..956cb00dbc1c5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
@@ -0,0 +1,17 @@
+; RUN: llc %s -enable-precise --filetype=asm -o - | FileCheck %s --check-prefixes=ENABLED
+; RUN: llc %s --filetype=asm -o - | FileCheck %s --check-prefixes=DISABLED
+
+
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+; ENABLED: call float @dx.op.unary.f32(i32 12, float %conv)
+; ENABLED-SAME: !dx.precise ![[SM:[0-9]+]]
+; ENABLED: ![[SM]] = !{i32 1}
+
+; DISABLED-NOT: !dx.precise ![[SM:[0-9]+]]
+define void @main(float %conv) {
+entry:
+ %1 = call float @llvm.cos.f32(float %conv)
+ ret void
+}
>From 41a49118731f6b178a73dae42afe359dadb3324d Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 13:15:09 -0700
Subject: [PATCH 3/5] add comment
---
llvm/lib/Target/DirectX/DXIL.td | 1 +
1 file changed, 1 insertion(+)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 8633d1b9f43d1..2981fa8721f65 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -406,6 +406,7 @@ class DXILOp<int opcode, DXILOpClass opclass> {
// Versioned attributes of operation
list<Attributes> attributes = [];
+ // Emit dx.precise to prevent optimizations
bit precise = 0;
}
>From 5502c2d358291b4a5478668529cdbfd2a94d4c26 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Thu, 16 Apr 2026 15:57:43 -0700
Subject: [PATCH 4/5] address damyan comments
---
llvm/lib/Target/DirectX/DXIL.td | 5 +-
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 10 +--
.../CodeGen/DirectX/Metadata/dx_precise.ll | 69 ++++++++++++++++++-
llvm/utils/TableGen/DXILEmitter.cpp | 2 +-
4 files changed, 73 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 2981fa8721f65..47dca39679d3f 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -406,7 +406,7 @@ class DXILOp<int opcode, DXILOpClass opclass> {
// Versioned attributes of operation
list<Attributes> attributes = [];
- // Emit dx.precise to prevent optimizations
+ // Set to 1 if calls to this operation may be tagged with dx.precise (only when -enable-precise is passed).
bit precise = 0;
}
@@ -422,7 +422,6 @@ def Abs : DXILOp<6, unary> {
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
- let precise = 1;
}
def Saturate : DXILOp<7, unary> {
@@ -1302,7 +1301,6 @@ def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
let arguments = [Int32Ty];
let result = FloatTy;
let stages = [Stages<DXIL1_0, [all_stages]>];
- let precise = 1;
}
def WaveAllBitCount : DXILOp<135, waveAllOp> {
@@ -1358,7 +1356,6 @@ def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
let overloads = [Overloads<DXIL1_4, [FloatTy]>];
let stages = [Stages<DXIL1_4, [all_stages]>];
let attributes = [Attributes<DXIL1_4, [ReadNone]>];
- let precise = 1;
}
def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index e0a9f78a18a26..c5ffed559b16e 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -500,13 +500,13 @@ static bool isOverloadTyFloat(uint16_t ValidTyMask) {
}
static void setDXILMetadata(CallInst *CI, const OpCodeProperty *Prop) {
- if (OpPreciseEnabled) {
- bool AllOverloadAreFloat = false;
+ if (OpPreciseEnabled && Prop->Precise) {
+ bool HasFloatOverload = false;
for (const OpOverload &Overload : Prop->Overloads)
- AllOverloadAreFloat =
- AllOverloadAreFloat || isOverloadTyFloat(Overload.ValidTys);
+ HasFloatOverload =
+ HasFloatOverload || isOverloadTyFloat(Overload.ValidTys);
- if (AllOverloadAreFloat && Prop->Precise) {
+ if (HasFloatOverload) {
const StringRef Key = "dx.precise";
Module *M = CI->getModule();
diff --git a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
index 956cb00dbc1c5..e2dc635b7b1db 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
@@ -5,13 +5,76 @@
target triple = "dxil-pc-shadermodel6.6-compute"
-; ENABLED: call float @dx.op.unary.f32(i32 12, float %conv)
+; ENABLED: call float @dx.op.unary.f32(i32 7,
; ENABLED-SAME: !dx.precise ![[SM:[0-9]+]]
+; ENABLED-COUNT-26: !dx.precise ![[SM]]
+; ENABLED-NOT: !dx.precise ![[SM]]
; ENABLED: ![[SM]] = !{i32 1}
; DISABLED-NOT: !dx.precise ![[SM:[0-9]+]]
-define void @main(float %conv) {
+define void @unary(float %p) {
entry:
- %1 = call float @llvm.cos.f32(float %conv)
+ %1 = call float @llvm.dx.saturate.f32(float %p)
+ %2 = call float @llvm.cos.f32(float %p)
+ %3 = call float @llvm.sin.f32(float %p)
+ %4 = call float @llvm.tan.f32(float %p)
+ %5 = call float @llvm.acos.f32(float %p)
+ %6 = call float @llvm.asin.f32(float %p)
+ %7 = call float @llvm.atan.f32(float %p)
+ %8 = call float @llvm.cosh.f32(float %p)
+ %9 = call float @llvm.sinh.f32(float %p)
+ %10 = call float @llvm.tanh.f32(float %p)
+ %11 = call float @llvm.exp2.f32(float %p)
+ %12 = call float @llvm.dx.frac.f32(float %p)
+ %13 = call float @llvm.log2.f32(float %p)
+ %14 = call float @llvm.log2.f32(float %p)
+ %15 = call float @llvm.sqrt.f32(float %p)
+ %16 = call float @llvm.roundeven.f32(float %p)
+ %17 = call float @llvm.floor.f32(float %p)
+ %18 = call float @llvm.ceil.f32(float %p)
+ %19 = call float @llvm.trunc.f32(float %p)
+ ret void
+}
+
+define void @binary(float %p1, float %p2) {
+entry:
+ %20 = call float @llvm.maxnum.f32(float %p1, float %p2)
+ %21 = call float @llvm.minnum.f32(float %p1, float %p2)
+ ret void
+}
+
+define void @tertiary(float %p1, float %p2, float %p3) {
+entry:
+ %22 = call float @llvm.fmuladd.f32(float %p1, float %p2, float %p3)
+ ret void
+}
+; Fma only has a DoubleTy overload in DXIL.td, so it is exercised with f64 operands.
+define void @fma(double %p1, double %p2, double %p3) {
+entry:
+ %23 = call double @llvm.fma.f64(double %p1, double %p2, double %p3)
+ ret void
+}
+
+define void @dot2(<2 x float> %a, <2 x float> %b) {
+entry:
+ %24 = call float @llvm.dx.fdot.v2f32(<2 x float> %a, <2 x float> %b)
+ ret void
+}
+
+define void @dot3(<3 x float> %a, <3 x float> %b) {
+entry:
+ %25 = call float @llvm.dx.fdot.v3f32(<3 x float> %a, <3 x float> %b)
+ ret void
+}
+
+define void @dot4(<4 x float> %a, <4 x float> %b) {
+entry:
+ %26 = call float @llvm.dx.fdot.v4f32(<4 x float> %a, <4 x float> %b)
+ ret void
+}
+
+define void @wave_rla(float %expr, i32 %idx) {
+entry:
+ %27 = call float @llvm.dx.wave.readlane(float %expr, i32 %idx)
ret void
}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 26f7ed7f7f502..e32c1f4485469 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -42,7 +42,7 @@ struct DXILOperationDesc {
int OpCode; // ID of DXIL operation
StringRef OpClass; // name of the opcode class
StringRef Doc; // the documentation description of this instruction
- bool CanUsePrecise; // Can this operation be maker with dx.precise
+ bool CanUsePrecise; // Can this operation be marked with dx.precise?
// Vector of operand type records - return type is at index 0
SmallVector<const Record *> OpTypes;
SmallVector<const Record *> OverloadRecs;
>From e07febc0ffb72cfc9304503956d4966a5d69eb62 Mon Sep 17 00:00:00 2001
From: Joao Saffran <joaosaffranllvm at gmail.com>
Date: Fri, 17 Apr 2026 11:17:15 -0700
Subject: [PATCH 5/5] simplify and improve testing
---
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 38 +++----
.../CodeGen/DirectX/Metadata/dx_precise.ll | 98 ++++++++++++++++---
2 files changed, 98 insertions(+), 38 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index c5ffed559b16e..414da9b26b9d1 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -491,32 +491,20 @@ static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
return;
}
-static bool isOverloadTyFloat(uint16_t ValidTyMask) {
- if (ValidTyMask == OverloadKind::UNDEFINED)
- return false;
- return (ValidTyMask &
- ((uint16_t)OverloadKind::HALF | (uint16_t)OverloadKind::FLOAT |
- (uint16_t)OverloadKind::DOUBLE)) != 0;
-}
-
static void setDXILMetadata(CallInst *CI, const OpCodeProperty *Prop) {
- if (OpPreciseEnabled && Prop->Precise) {
- bool HasFloatOverload = false;
- for (const OpOverload &Overload : Prop->Overloads)
- HasFloatOverload =
- HasFloatOverload || isOverloadTyFloat(Overload.ValidTys);
-
- if (HasFloatOverload) {
- const StringRef Key = "dx.precise";
- Module *M = CI->getModule();
-
- LLVMContext &Ctx = M->getContext();
- MDNode *One =
- llvm::MDNode::get(Ctx, ConstantAsMetadata::get(ConstantInt::get(
- llvm::Type::getInt32Ty(Ctx), 1)));
-
- CI->setMetadata(Key, One);
- }
+ // Tag the call only when -enable-precise is set, the op's DXIL.td record
+ // sets `precise`, and the call returns a floating-point value (ops such as
+ // WaveReadLaneAt are marked precise but also have integer overloads).
+ if (OpPreciseEnabled && Prop->Precise &&
+ CI->getType()->isFloatingPointTy()) {
+ const StringRef Key = "dx.precise";
+ LLVMContext &Ctx = CI->getContext();
+
+ // Payload is the single-operand node !{i32 1}.
+ MDNode *One =
+ llvm::MDNode::get(Ctx, ConstantAsMetadata::get(ConstantInt::get(
+ llvm::Type::getInt32Ty(Ctx), 1)));
+ CI->setMetadata(Key, One);
}
}
diff --git a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
index e2dc635b7b1db..2af993e4380f5 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/dx_precise.ll
@@ -7,12 +7,12 @@ target triple = "dxil-pc-shadermodel6.6-compute"
; ENABLED: call float @dx.op.unary.f32(i32 7,
; ENABLED-SAME: !dx.precise ![[SM:[0-9]+]]
-; ENABLED-COUNT-26: !dx.precise ![[SM]]
+; ENABLED-COUNT-52: !dx.precise ![[SM]]
; ENABLED-NOT: !dx.precise ![[SM]]
; ENABLED: ![[SM]] = !{i32 1}
; DISABLED-NOT: !dx.precise ![[SM:[0-9]+]]
-define void @unary(float %p) {
+define void @unary_f32(float %p) {
entry:
%1 = call float @llvm.dx.saturate.f32(float %p)
%2 = call float @llvm.cos.f32(float %p)
@@ -27,22 +27,64 @@ entry:
%11 = call float @llvm.exp2.f32(float %p)
%12 = call float @llvm.dx.frac.f32(float %p)
%13 = call float @llvm.log2.f32(float %p)
- %14 = call float @llvm.log2.f32(float %p)
- %15 = call float @llvm.sqrt.f32(float %p)
- %16 = call float @llvm.roundeven.f32(float %p)
- %17 = call float @llvm.floor.f32(float %p)
- %18 = call float @llvm.ceil.f32(float %p)
- %19 = call float @llvm.trunc.f32(float %p)
+ %14 = call float @llvm.sqrt.f32(float %p)
+ %15 = call float @llvm.roundeven.f32(float %p)
+ %16 = call float @llvm.floor.f32(float %p)
+ %17 = call float @llvm.ceil.f32(float %p)
+ %18 = call float @llvm.trunc.f32(float %p)
ret void
}
-define void @binary(float %p1, float %p2) {
+define void @unary_f16(half %p) {
+entry:
+ %1 = call half @llvm.dx.saturate.f16(half %p)
+ %2 = call half @llvm.cos.f16(half %p)
+ %3 = call half @llvm.sin.f16(half %p)
+ %4 = call half @llvm.tan.f16(half %p)
+ %5 = call half @llvm.acos.f16(half %p)
+ %6 = call half @llvm.asin.f16(half %p)
+ %7 = call half @llvm.atan.f16(half %p)
+ %8 = call half @llvm.cosh.f16(half %p)
+ %9 = call half @llvm.sinh.f16(half %p)
+ %10 = call half @llvm.tanh.f16(half %p)
+ %11 = call half @llvm.exp2.f16(half %p)
+ %12 = call half @llvm.dx.frac.f16(half %p)
+ %13 = call half @llvm.log2.f16(half %p)
+ %14 = call half @llvm.sqrt.f16(half %p)
+ %15 = call half @llvm.roundeven.f16(half %p)
+ %16 = call half @llvm.floor.f16(half %p)
+ %17 = call half @llvm.ceil.f16(half %p)
+ %18 = call half @llvm.trunc.f16(half %p)
+ ret void
+}
+
+define void @unary_f64(double %p) {
+entry:
+ %1 = call double @llvm.dx.saturate.f64(double %p)
+ ret void
+}
+
+define void @binary_f32(float %p1, float %p2) {
entry:
%20 = call float @llvm.maxnum.f32(float %p1, float %p2)
%21 = call float @llvm.minnum.f32(float %p1, float %p2)
ret void
}
+define void @binary_f16(half %p1, half %p2) {
+entry:
+ %20 = call half @llvm.maxnum.f16(half %p1, half %p2)
+ %21 = call half @llvm.minnum.f16(half %p1, half %p2)
+ ret void
+}
+
+define void @binary_f64(double %p1, double %p2) {
+entry:
+ %20 = call double @llvm.maxnum.f64(double %p1, double %p2)
+ %21 = call double @llvm.minnum.f64(double %p1, double %p2)
+ ret void
+}
+
define void @tertiary(float %p1, float %p2, float %p3) {
entry:
%22 = call float @llvm.fmuladd.f32(float %p1, float %p2, float %p3)
@@ -55,26 +97,56 @@ entry:
ret void
}
-define void @dot2(<2 x float> %a, <2 x float> %b) {
+define void @dot2_f32(<2 x float> %a, <2 x float> %b) {
entry:
%24 = call float @llvm.dx.fdot.v2f32(<2 x float> %a, <2 x float> %b)
ret void
}
-define void @dot3(<3 x float> %a, <3 x float> %b) {
+define void @dot2_f16(<2 x half> %a, <2 x half> %b) {
+entry:
+ %24 = call half @llvm.dx.fdot.v2f16(<2 x half> %a, <2 x half> %b)
+ ret void
+}
+
+define void @dot3_f32(<3 x float> %a, <3 x float> %b) {
entry:
%25 = call float @llvm.dx.fdot.v3f32(<3 x float> %a, <3 x float> %b)
ret void
}
-define void @dot4(<4 x float> %a, <4 x float> %b) {
+define void @dot3_f16(<3 x half> %a, <3 x half> %b) {
+entry:
+ %25 = call half @llvm.dx.fdot.v3f16(<3 x half> %a, <3 x half> %b)
+ ret void
+}
+
+define void @dot4_f32(<4 x float> %a, <4 x float> %b) {
entry:
%26 = call float @llvm.dx.fdot.v4f32(<4 x float> %a, <4 x float> %b)
ret void
}
-define void @wave_rla(float %expr, i32 %idx) {
+define void @dot4_f16(<4 x half> %a, <4 x half> %b) {
+entry:
+ %26 = call half @llvm.dx.fdot.v4f16(<4 x half> %a, <4 x half> %b)
+ ret void
+}
+
+define void @wave_rla_f32(float %expr, i32 %idx) {
entry:
%27 = call float @llvm.dx.wave.readlane(float %expr, i32 %idx)
ret void
}
+
+define void @wave_rla_f16(half %expr, i32 %idx) {
+entry:
+ %27 = call half @llvm.dx.wave.readlane(half %expr, i32 %idx)
+ ret void
+}
+
+define void @wave_rla_i32(i32 %expr, i32 %idx) {
+entry:
+ %27 = call i32 @llvm.dx.wave.readlane(i32 %expr, i32 %idx)
+ ret void
+}
More information about the cfe-commits
mailing list