[llvm] [DirectX] Add support to lower LLVM intrinsics ceil, cos, fabs, and floor to DXIL Ops. (PR #80350)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 1 14:00:00 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: S. Bharadwaj Yadavalli (bharadwajy)
<details>
<summary>Changes</summary>
- Updated the structure and format of the DXIL Op overload type representation in DXIL.td to match the DXIL specification.
- Modified DXILEmitter accordingly.
- Added entries in DXIL.td describing the new DXIL Ops and their mappings to LLVM intrinsics.
- Added tests to verify the lowering of LLVM intrinsics to the new DXIL Ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/80350.diff
7 Files Affected:
- (modified) llvm/lib/Target/DirectX/DXIL.td (+89-41)
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+2)
- (added) llvm/test/CodeGen/DirectX/ceil.ll (+17)
- (added) llvm/test/CodeGen/DirectX/cos.ll (+17)
- (added) llvm/test/CodeGen/DirectX/fabs.ll (+25)
- (added) llvm/test/CodeGen/DirectX/floor.ll (+17)
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+31-18)
``````````diff
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 709279889653b..83cf038d0767a 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,23 +13,41 @@
include "llvm/IR/Intrinsics.td"
-class dxil_class<string _name> {
+// Abstract DXIL Op Class representation
+class DXILClass<string _name> {
string name = _name;
}
-class dxil_category<string _name> {
+
+// Abstract DXIL Op Category representation
+class DXILCategory<string _name> {
string name = _name;
}
-def Unary : dxil_class<"Unary">;
-def Binary : dxil_class<"Binary">;
-def FlattenedThreadIdInGroupClass : dxil_class<"FlattenedThreadIdInGroup">;
-def ThreadIdInGroupClass : dxil_class<"ThreadIdInGroup">;
-def ThreadIdClass : dxil_class<"ThreadId">;
-def GroupIdClass : dxil_class<"GroupId">;
+// Abstract type description record; `desc` is a human-readable label.
+class Type<string _desc> {
+  string desc = _desc;
+}
+
+// Basic Types (NOTE(review): not yet referenced by any DXIL Op in this patch)
+def float_t : Type<"float">;
+def rfloat_t : Type<"float with rounding">;
+def int_t : Type<"int">;
+def uint_t : Type<"uint">;
+def uint_cb_t : Type<"uint with carry or borrow">;
+def uint_two_outs_t : Type<"uint with two outputs">;
+
+// DXIL Op Classes (NOTE(review): flattened_threadIdIn_group breaks the naming pattern of its siblings)
+def unary : DXILClass<"Unary">;
+def binary : DXILClass<"Binary">;
+def flattened_threadIdIn_group : DXILClass<"FlattenedThreadIdInGroup">;
+def threadId_in_group : DXILClass<"ThreadIdInGroup">;
+def thread_id : DXILClass<"ThreadId">;
+def group_id : DXILClass<"GroupId">;
-def binary_uint : dxil_category<"Binary uint">;
-def unary_float : dxil_category<"Unary float">;
-def ComputeID : dxil_category<"Compute/Mesh/Amplification shader">;
+// DXIL Op Categories
+def binary_uint : DXILCategory<"Binary uint">;
+def unary_float : DXILCategory<"Unary float">;
+def compute_id : DXILCategory<"Compute/Mesh/Amplification shader">;
// The parameter description for a DXIL instruction
@@ -49,13 +67,14 @@ class dxil_param<int _pos, string type, string _name, string _doc,
}
// A representation for a DXIL instruction
-class dxil_inst<string _name> {
- string name = _name; // short, unique name
+class dxil_inst {
+ // TODO : Appears redundant. dxil_op should serve the same purpose
+ string name = ""; // short, unique name
string dxil_op = ""; // name of DXIL operation
int dxil_opid = 0; // ID of DXIL operation
- dxil_class op_class; // name of the opcode class
- dxil_category category; // classification for this instruction
+ DXILClass op_class; // name of the opcode class
+ DXILCategory category; // classification for this instruction
string doc = ""; // the documentation description of this instruction
list<dxil_param> ops = []; // the operands that this instruction takes
string oload_types = ""; // overload types if applicable
@@ -72,10 +91,11 @@ class dxil_inst<string _name> {
list<string> stats_group = [];
}
-class dxil_op<string name, int code_id, dxil_class code_class, dxil_category op_category, string _doc,
+class dxil_op<string _name, int code_id, DXILClass code_class, DXILCategory op_category, string _doc,
string _oload_types, string _fn_attr, list<dxil_param> op_params,
- list<string> _stats_group = []> : dxil_inst<name> {
- let dxil_op = name;
+ list<string> _stats_group = []> : dxil_inst {
+ let name = _name;
+ let dxil_op = _name;
let dxil_opid = code_id;
let doc = _doc;
let ops = op_params;
@@ -86,31 +106,59 @@ class dxil_op<string name, int code_id, dxil_class code_class, dxil_category op_
let stats_group = _stats_group;
}
-// The intrinsic which map directly to this dxil op.
+// Minimal DXIL Op: sets name/opcode/doc only; other dxil_inst fields are expected from an enclosing `let`.
+class dxil_op_min<string _name, int code_id, string _doc> : dxil_inst {
+ let name = _name;
+ let dxil_op = _name;
+ let dxil_opid = code_id;
+ let doc = _doc;
+}
+
+// LLVM intrinsic to map dxil op to.
class dxil_map_intrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
-def Sin : dxil_op<"Sin", 13, Unary, unary_float, "returns sine(theta) for theta in radians.",
- "half;float;", "rn",
- [
- dxil_param<0, "$o", "", "operation result">,
- dxil_param<1, "i32", "opcode", "DXIL opcode">,
- dxil_param<2, "$o", "value", "input value">
- ],
- ["floats"]>,
- dxil_map_intrinsic<int_sin>;
-
-def UMax :dxil_op< "UMax", 39, Binary, binary_uint, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
- "i16;i32;i64;", "rn",
- [
+// Unary class DXIL Ops: the result and the single input share the overload type ($o)
+let op_class = unary, ops = [
+ dxil_param<0, "$o", "", "operation result">,
+ dxil_param<1, "i32", "opcode", "DXIL opcode">,
+ dxil_param<2, "$o", "value", "input value">
+ ], fn_attr = "rn" in {
+ // Unary half/float DXIL Ops (oload_types "hf" = half, float)
+ let category = unary_float, oload_types = "hf", stats_group = ["floats"] in {
+ def Cos : dxil_op_min<"Cos", 12, "returns cosine(theta) for theta in radians.">,
+ dxil_map_intrinsic<int_cos>;
+
+ def Sin : dxil_op_min<"Sin", 13, "returns sine(theta) for theta in radians.">,
+ dxil_map_intrinsic<int_sin>;
+
+ def Round_ni : dxil_op_min<"Round_ni", 27, "floating-point round to integral float towards -INF, commonly known as floor().">,
+ dxil_map_intrinsic<int_floor>;
+
+ def Round_pi : dxil_op_min<"Round_pi", 28, "floating-point round to integral float towards +INF, commonly known as ceil().">,
+ dxil_map_intrinsic<int_ceil>;
+ }
+ // Unary half/float/double DXIL Ops (oload_types "hfd" = half, float, double)
+ let category = unary_float, oload_types = "hfd", stats_group = ["floats"] in {
+ def FAbs : dxil_op_min<"FAbs", 6, "returns the absolute value of the input value.">,
+ dxil_map_intrinsic<int_fabs>;
+ }
+}
+
+// Binary class DXIL Ops: the result and both inputs share the overload type ($o)
+let op_class = binary, ops = [
dxil_param<0, "$o", "", "operation result">,
dxil_param<1, "i32", "opcode", "DXIL opcode">,
dxil_param<2, "$o", "a", "input value">,
dxil_param<3, "$o", "b", "input value">
- ],
- ["uints"]>,
- dxil_map_intrinsic<int_umax>;
+ ], fn_attr = "rn" in
+// Binary uint DXIL Ops (oload_types "wil" = i16, i32, i64)
+ let category = binary_uint, oload_types = "wil", stats_group = ["uints"] in
+ def UMax : dxil_op_min<"UMax", 39, "unsigned integer maximum. UMax(a,b) = a > b ? a : b">,
+ dxil_map_intrinsic<int_umax>;
+
+// Compute ID class DXIL Ops
-def ThreadId :dxil_op< "ThreadId", 93, ThreadIdClass, ComputeID, "reads the thread ID", "i32;", "rn",
+def ThreadId :dxil_op< "ThreadId", 93, thread_id, compute_id, "reads the thread ID", "i", "rn",
[
dxil_param<0, "i32", "", "thread ID component">,
dxil_param<1, "i32", "opcode", "DXIL opcode">,
@@ -118,7 +166,7 @@ def ThreadId :dxil_op< "ThreadId", 93, ThreadIdClass, ComputeID, "reads the thr
]>,
dxil_map_intrinsic<int_dx_thread_id>;
-def GroupId :dxil_op< "GroupId", 94, GroupIdClass, ComputeID, "reads the group ID (SV_GroupID)", "i32;", "rn",
+def GroupId :dxil_op< "GroupId", 94, group_id, compute_id, "reads the group ID (SV_GroupID)", "i", "rn",
[
dxil_param<0, "i32", "", "group ID component">,
dxil_param<1, "i32", "opcode", "DXIL opcode">,
@@ -126,8 +174,8 @@ def GroupId :dxil_op< "GroupId", 94, GroupIdClass, ComputeID, "reads the group
]>,
dxil_map_intrinsic<int_dx_group_id>;
-def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95, ThreadIdInGroupClass, ComputeID,
- "reads the thread ID within the group (SV_GroupThreadID)", "i32;", "rn",
+def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95, threadId_in_group, compute_id,
+ "reads the thread ID within the group (SV_GroupThreadID)", "i", "rn",
[
dxil_param<0, "i32", "", "thread ID in group component">,
dxil_param<1, "i32", "opcode", "DXIL opcode">,
@@ -135,8 +183,8 @@ def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95, ThreadIdInGroupClass, Comp
]>,
dxil_map_intrinsic<int_dx_thread_id_in_group>;
-def FlattenedThreadIdInGroup :dxil_op< "FlattenedThreadIdInGroup", 96, FlattenedThreadIdInGroupClass, ComputeID,
- "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i32;", "rn",
+def FlattenedThreadIdInGroup :dxil_op< "FlattenedThreadIdInGroup", 96, flattened_threadIdIn_group, compute_id,
+ "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i", "rn",
[
dxil_param<0, "i32", "", "result">,
dxil_param<1, "i32", "opcode", "DXIL opcode">
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f6e2297e9af41..cbfd65e27983f 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -45,6 +45,8 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
Args.append(CI->arg_begin(), CI->arg_end());
B.SetInsertPoint(CI);
CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+ // Retain tail call property
+ DXILCI->setTailCall(CI->isTailCall());
CI->replaceAllUsesWith(DXILCI);
CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/ceil.ll b/llvm/test/CodeGen/DirectX/ceil.ll
new file mode 100644
index 0000000000000..29cfa52e50215
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ceil.ll
@@ -0,0 +1,17 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for ceil are generated for half and float.
+
+define noundef half @test_ceil_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 28, half %{{.*}})
+  %0 = call half @llvm.ceil.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_ceil_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 28, float %{{.*}})
+  %0 = call float @llvm.ceil.f32(float %a)
+  ret float %0
+}
diff --git a/llvm/test/CodeGen/DirectX/cos.ll b/llvm/test/CodeGen/DirectX/cos.ll
new file mode 100644
index 0000000000000..f83b24cb38991
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/cos.ll
@@ -0,0 +1,17 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for cos are generated for half and float.
+
+define noundef half @test_cos_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 12, half %{{.*}})
+  %0 = call half @llvm.cos.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_cos_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 12, float %{{.*}})
+  %0 = call float @llvm.cos.f32(float %a)
+  ret float %0
+}
diff --git a/llvm/test/CodeGen/DirectX/fabs.ll b/llvm/test/CodeGen/DirectX/fabs.ll
new file mode 100644
index 0000000000000..25c3c6ed8a47d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fabs.ll
@@ -0,0 +1,25 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for fabs are generated
+; for half, float and double.
+
+define noundef half @test_fabs_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 6, half %{{.*}})
+  %0 = call half @llvm.fabs.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_fabs_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 6, float %{{.*}})
+  %0 = call float @llvm.fabs.f32(float %a)
+  ret float %0
+}
+
+define noundef double @test_fabs_double(double noundef %a) {
+entry:
+  ; CHECK: call double @dx.op.unary.f64(i32 6, double %{{.*}})
+  %0 = call double @llvm.fabs.f64(double %a)
+  ret double %0
+}
diff --git a/llvm/test/CodeGen/DirectX/floor.ll b/llvm/test/CodeGen/DirectX/floor.ll
new file mode 100644
index 0000000000000..df69804513f8d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/floor.ll
@@ -0,0 +1,17 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for floor are generated for half and float.
+
+define noundef half @test_floor_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 27, half %{{.*}})
+  %0 = call half @llvm.floor.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_floor_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 27, float %{{.*}})
+  %0 = call float @llvm.floor.f32(float %a)
+  ret float %0
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index ddc7cfb813447..0a53f3156205a 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/DXILABI.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
+#include <algorithm>
using namespace llvm;
using namespace llvm::dxil;
@@ -277,30 +278,42 @@ static std::string emitDXILOperationFnAttr(StringRef FnAttr) {
.Default("Attribute::None");
}
-static std::string getOverloadKind(StringRef Overload) {
- return StringSwitch<std::string>(Overload)
- .Case("half", "OverloadKind::HALF")
- .Case("float", "OverloadKind::FLOAT")
- .Case("double", "OverloadKind::DOUBLE")
- .Case("i1", "OverloadKind::I1")
- .Case("i16", "OverloadKind::I16")
- .Case("i32", "OverloadKind::I32")
- .Case("i64", "OverloadKind::I64")
- .Case("udt", "OverloadKind::UserDefineType")
- .Case("obj", "OverloadKind::ObjectType")
- .Default("OverloadKind::VOID");
+// Map a single-character overload type code, as used in the DXIL specification
+// (e.g. 'h' = half, 'w' = i16, 'l' = i64), to the matching OverloadKind name.
+static std::string getOverloadKind(char Overload) {
+  switch (Overload) {
+  case 'h':
+    return "OverloadKind::HALF";
+  case 'f':
+    return "OverloadKind::FLOAT";
+  case 'd':
+    return "OverloadKind::DOUBLE";
+  case '1':
+    return "OverloadKind::I1";
+  case 'w':
+    return "OverloadKind::I16";
+  case 'i':
+    return "OverloadKind::I32";
+  case 'l':
+    return "OverloadKind::I64";
+  case 'u': // Enumerator is spelled UserDefineType (no 'd'), as before.
+    return "OverloadKind::UserDefineType";
+  case 'o':
+    return "OverloadKind::ObjectType";
+  case 'v':
+    return "OverloadKind::VOID";
+  default: // Bad .td input is a user error, not an unreachable state.
+    report_fatal_error("Unknown overload kind '" + Twine(Overload) + "'");
+  }
}
static std::string getDXILOperationOverload(StringRef Overloads) {
- SmallVector<StringRef> OverloadStrs;
- Overloads.split(OverloadStrs, ';', /*MaxSplit*/ -1, /*KeepEmpty*/ false);
// Format is: OverloadKind::FLOAT | OverloadKind::HALF
- assert(!OverloadStrs.empty() && "Invalid overloads");
- auto It = OverloadStrs.begin();
+ auto It = Overloads.begin();
std::string Result;
raw_string_ostream OS(Result);
OS << getOverloadKind(*It);
- for (++It; It != OverloadStrs.end(); ++It) {
+ for (++It; It != Overloads.end(); ++It) {
OS << " | " << getOverloadKind(*It);
}
return OS.str();
@@ -431,7 +444,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationData> &DXILOps,
}
static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
- std::vector<Record *> Ops = Records.getAllDerivedDefinitions("dxil_op");
+ std::vector<Record *> Ops = Records.getAllDerivedDefinitions("dxil_inst");
OS << "// Generated code, do not edit.\n";
OS << "\n";
``````````
</details>
https://github.com/llvm/llvm-project/pull/80350
More information about the llvm-commits
mailing list