[llvm] [DirectX] Add support to lower LLVM intrinsics ceil, cos, fabs, and floor to DXIL Ops. (PR #80350)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 13:59:08 PST 2024


https://github.com/bharadwajy created https://github.com/llvm/llvm-project/pull/80350

- Restructured DXIL.td and changed the DXIL Op overload-type representation to match the notation used in the DXIL specification.
- Modified DXILEmitter accordingly.
- Added entries in DXIL.td to describe the new DXIL operations.
- Added tests to verify that the intrinsics are correctly lowered to the new DXIL operations.

>From 183418d89b2ead15607dde88bc984a2315255d73 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 1 Feb 2024 10:59:41 -0500
Subject: [PATCH 1/4] Add DXIL Op description for LLVM intrinsic cos. Add test
 to verify lowering of cos. Restructure DXIL.td.

---
 llvm/lib/Target/DirectX/DXIL.td            | 114 ++++++++++++++-------
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |   2 +
 llvm/test/CodeGen/DirectX/cos.ll           |  17 +++
 llvm/utils/TableGen/DXILEmitter.cpp        |   3 +-
 4 files changed, 96 insertions(+), 40 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/cos.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 709279889653b..d1b7a7ef4e252 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,23 +13,41 @@
 
 include "llvm/IR/Intrinsics.td"
 
-class dxil_class<string _name> {
+// Abstract DXIL Op Class representation
+class DXILClass<string _name> {
   string name = _name;
 }
-class dxil_category<string _name> {
+
+// Abstract DXIL Op Category representation
+class DXILCategory<string _name> {
   string name = _name;
 }
 
-def Unary : dxil_class<"Unary">;
-def Binary : dxil_class<"Binary">;
-def FlattenedThreadIdInGroupClass : dxil_class<"FlattenedThreadIdInGroup">;
-def ThreadIdInGroupClass : dxil_class<"ThreadIdInGroup">;
-def ThreadIdClass : dxil_class<"ThreadId">;
-def GroupIdClass : dxil_class<"GroupId">;
+// Abstract Type representation
+class Type<string _desc> {
+  string desc = _desc;
+}
+
+// Basic Types
+def float_t : Type<"float">;
+def rfloat_t : Type<"float with rounding">;
+def int_t : Type<"int">;
+def uint_t : Type<"uint">;
+def uint_cb_t : Type<"uint with carry or borrow">;
+def uint_two_outs_t : Type<"uint with two outputs">;
 
-def binary_uint : dxil_category<"Binary uint">;
-def unary_float : dxil_category<"Unary float">;
-def ComputeID : dxil_category<"Compute/Mesh/Amplification shader">;
+// DXIL Op Classes
+def unary : DXILClass<"Unary">;
+def binary : DXILClass<"Binary">;
+def flattened_threadIdIn_group : DXILClass<"FlattenedThreadIdInGroup">;
+def threadId_in_group : DXILClass<"ThreadIdInGroup">;
+def thread_id : DXILClass<"ThreadId">;
+def group_id : DXILClass<"GroupId">;
+
+// DXIL Op Categories
+def binary_uint : DXILCategory<"Binary uint">;
+def unary_float : DXILCategory<"Unary float">;
+def compute_id : DXILCategory<"Compute/Mesh/Amplification shader">;
 
 
 // The parameter description for a DXIL instruction
@@ -49,13 +67,14 @@ class dxil_param<int _pos, string type, string _name, string _doc,
 }
 
 // A representation for a DXIL instruction
-class dxil_inst<string _name> {
-  string name = _name; // short, unique name
+class dxil_inst {
+  // TODO: Appears redundant; dxil_op is always assigned the same value.
+  string name = ""; // short, unique name
 
   string dxil_op = "";       // name of DXIL operation
   int dxil_opid = 0;         // ID of DXIL operation
-  dxil_class  op_class;      // name of the opcode class
-  dxil_category category;    // classification for this instruction
+  DXILClass op_class;        // name of the opcode class
+  DXILCategory category;     // classification for this instruction
   string doc = "";           // the documentation description of this instruction
   list<dxil_param> ops = []; // the operands that this instruction takes
   string oload_types = "";   // overload types if applicable
@@ -72,10 +91,11 @@ class dxil_inst<string _name> {
   list<string> stats_group = [];
 }
 
-class dxil_op<string name, int code_id, dxil_class code_class, dxil_category op_category, string _doc,
+class dxil_op<string _name, int code_id, DXILClass code_class, DXILCategory op_category, string _doc,
               string _oload_types, string _fn_attr, list<dxil_param> op_params,
-              list<string> _stats_group = []> : dxil_inst<name> {
-  let dxil_op = name;
+              list<string> _stats_group = []> : dxil_inst {
+  let name = _name;
+  let dxil_op = _name;
   let dxil_opid = code_id;
   let doc = _doc;
   let ops = op_params;
@@ -86,31 +106,47 @@ class dxil_op<string name, int code_id, dxil_class code_class, dxil_category op_
   let stats_group = _stats_group;
 }
 
-// The intrinsic which map directly to this dxil op.
+class dxil_op_min<string _name, int code_id, string _doc> : dxil_inst {
+  let name = _name;
+  let dxil_op = _name;
+  let dxil_opid = code_id;
+  let doc = _doc;
+}
+
+
+// Names the LLVM intrinsic that is lowered to this DXIL op.
 class dxil_map_intrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
 
-def Sin : dxil_op<"Sin", 13, Unary, unary_float, "returns sine(theta) for theta in radians.",
-  "half;float;", "rn",
-  [
-    dxil_param<0, "$o", "", "operation result">,
-    dxil_param<1, "i32", "opcode", "DXIL opcode">,
-    dxil_param<2, "$o", "value", "input value">
-  ],
-  ["floats"]>,
-  dxil_map_intrinsic<int_sin>;
-
-def UMax :dxil_op< "UMax", 39,  Binary,  binary_uint, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
-    "i16;i32;i64;",  "rn",
-  [
+// Unary class DXIL Ops
+let op_class = unary, ops = [
+      dxil_param<0, "$o", "", "operation result">,
+      dxil_param<1, "i32", "opcode", "DXIL opcode">,
+      dxil_param<2, "$o", "value", "input value">
+    ], fn_attr = "rn" in
+// Unary float DXIL Ops
+  let category = unary_float, oload_types = "half;float;", stats_group = ["floats"] in {
+    def Cos : dxil_op_min<"Cos", 12, "returns cosine(theta) for theta in radians.">,
+              dxil_map_intrinsic<int_cos>;
+
+    def Sin : dxil_op_min<"Sin", 13, "returns sine(theta) for theta in radians.">,
+              dxil_map_intrinsic<int_sin>;
+  }
+
+// Binary class DXIL Ops
+let op_class = binary, ops = [
     dxil_param<0,  "$o",  "",  "operation result">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">,
     dxil_param<2,  "$o",  "a",  "input value">,
     dxil_param<3,  "$o",  "b",  "input value">
-  ],
-  ["uints"]>,
-  dxil_map_intrinsic<int_umax>;
+  ], fn_attr = "rn" in
+// Binary uint DXIL Ops
+  let category = binary_uint, oload_types = "i16;i32;i64;", stats_group = ["uints"] in
+    def UMax : dxil_op_min<"UMax", 39, "unsigned integer maximum. UMax(a,b) = a > b ? a : b">,
+                   dxil_map_intrinsic<int_umax>;
+
+// Compute ID class DXIL Ops
 
-def ThreadId :dxil_op< "ThreadId", 93,  ThreadIdClass, ComputeID, "reads the thread ID", "i32;",  "rn",
+def ThreadId :dxil_op< "ThreadId", 93,  thread_id, compute_id, "reads the thread ID", "i32;",  "rn",
   [
     dxil_param<0,  "i32",  "",  "thread ID component">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -118,7 +154,7 @@ def ThreadId :dxil_op< "ThreadId", 93,  ThreadIdClass, ComputeID, "reads the thr
   ]>,
   dxil_map_intrinsic<int_dx_thread_id>;
 
-def GroupId :dxil_op< "GroupId", 94,  GroupIdClass, ComputeID, "reads the group ID (SV_GroupID)", "i32;",  "rn",
+def GroupId :dxil_op< "GroupId", 94,  group_id, compute_id, "reads the group ID (SV_GroupID)", "i32;",  "rn",
   [
     dxil_param<0,  "i32",  "",  "group ID component">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -126,7 +162,7 @@ def GroupId :dxil_op< "GroupId", 94,  GroupIdClass, ComputeID, "reads the group
   ]>,
   dxil_map_intrinsic<int_dx_group_id>;
 
-def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95,  ThreadIdInGroupClass, ComputeID,
+def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95,  threadId_in_group, compute_id,
   "reads the thread ID within the group (SV_GroupThreadID)", "i32;",  "rn",
   [
     dxil_param<0,  "i32",  "",  "thread ID in group component">,
@@ -135,7 +171,7 @@ def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95,  ThreadIdInGroupClass, Comp
   ]>,
   dxil_map_intrinsic<int_dx_thread_id_in_group>;
 
-def FlattenedThreadIdInGroup :dxil_op< "FlattenedThreadIdInGroup", 96,  FlattenedThreadIdInGroupClass, ComputeID,
+def FlattenedThreadIdInGroup :dxil_op< "FlattenedThreadIdInGroup", 96,  flattened_threadIdIn_group, compute_id,
    "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i32;",  "rn",
   [
     dxil_param<0,  "i32",  "",  "result">,
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f6e2297e9af41..cbfd65e27983f 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -45,6 +45,8 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     Args.append(CI->arg_begin(), CI->arg_end());
     B.SetInsertPoint(CI);
     CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+    // Copy the tail-call marker from the intrinsic call being replaced.
+    DXILCI->setTailCall(CI->isTailCall());
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/cos.ll b/llvm/test/CodeGen/DirectX/cos.ll
new file mode 100644
index 0000000000000..f83b24cb38991
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/cos.ll
@@ -0,0 +1,20 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for cos are generated for float and half.
+
+define noundef half @test_cos_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 12, half %{{.*}})
+  %0 = call half @llvm.cos.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_cos_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 12, float %{{.*}})
+  %0 = call float @llvm.cos.f32(float %a)
+  ret float %0
+}
+
+declare half @llvm.cos.f16(half)
+declare float @llvm.cos.f32(float)
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index ddc7cfb813447..dbf0a7aeaea17 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Support/DXILABI.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
+#include <algorithm>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -431,7 +432,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationData> &DXILOps,
 }
 
 static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
-  std::vector<Record *> Ops = Records.getAllDerivedDefinitions("dxil_op");
+  std::vector<Record *> Ops = Records.getAllDerivedDefinitions("dxil_inst");
   OS << "// Generated code, do not edit.\n";
   OS << "\n";
 

>From 8b7c228ee2ef53ff2730761db3cde9917caa1d30 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 1 Feb 2024 14:59:47 -0500
Subject: [PATCH 2/4] Change representation of DXIL Op overload type to match
 that in DXIL Specification.

---
 llvm/lib/Target/DirectX/DXIL.td     | 12 ++++----
 llvm/utils/TableGen/DXILEmitter.cpp | 46 ++++++++++++++++++-----------
 2 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index d1b7a7ef4e252..e6e79e6bbe1c5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -124,7 +124,7 @@ let op_class = unary, ops = [
       dxil_param<2, "$o", "value", "input value">
     ], fn_attr = "rn" in
 // Unary float DXIL Ops
-  let category = unary_float, oload_types = "half;float;", stats_group = ["floats"] in {
+  let category = unary_float, oload_types = "hf", stats_group = ["floats"] in {
     def Cos : dxil_op_min<"Cos", 12, "returns cosine(theta) for theta in radians.">,
               dxil_map_intrinsic<int_cos>;
 
@@ -140,13 +140,13 @@ let op_class = binary, ops = [
     dxil_param<3,  "$o",  "b",  "input value">
   ], fn_attr = "rn" in
 // Binary uint DXIL Ops
-  let category = binary_uint, oload_types = "i16;i32;i64;", stats_group = ["uints"] in
+  let category = binary_uint, oload_types = "wil", stats_group = ["uints"] in
     def UMax : dxil_op_min<"UMax", 39, "unsigned integer maximum. UMax(a,b) = a > b ? a : b">,
                    dxil_map_intrinsic<int_umax>;
 
 // Compute ID class DXIL Ops
 
-def ThreadId :dxil_op< "ThreadId", 93,  thread_id, compute_id, "reads the thread ID", "i32;",  "rn",
+def ThreadId :dxil_op< "ThreadId", 93,  thread_id, compute_id, "reads the thread ID", "i",  "rn",
   [
     dxil_param<0,  "i32",  "",  "thread ID component">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -154,7 +154,7 @@ def ThreadId :dxil_op< "ThreadId", 93,  thread_id, compute_id, "reads the thread
   ]>,
   dxil_map_intrinsic<int_dx_thread_id>;
 
-def GroupId :dxil_op< "GroupId", 94,  group_id, compute_id, "reads the group ID (SV_GroupID)", "i32;",  "rn",
+def GroupId :dxil_op< "GroupId", 94,  group_id, compute_id, "reads the group ID (SV_GroupID)", "i",  "rn",
   [
     dxil_param<0,  "i32",  "",  "group ID component">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -163,7 +163,7 @@ def GroupId :dxil_op< "GroupId", 94,  group_id, compute_id, "reads the group ID
   dxil_map_intrinsic<int_dx_group_id>;
 
 def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95,  threadId_in_group, compute_id,
-  "reads the thread ID within the group (SV_GroupThreadID)", "i32;",  "rn",
+  "reads the thread ID within the group (SV_GroupThreadID)", "i",  "rn",
   [
     dxil_param<0,  "i32",  "",  "thread ID in group component">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -172,7 +172,7 @@ def ThreadIdInGroup :dxil_op< "ThreadIdInGroup", 95,  threadId_in_group, compute
   dxil_map_intrinsic<int_dx_thread_id_in_group>;
 
 def FlattenedThreadIdInGroup :dxil_op< "FlattenedThreadIdInGroup", 96,  flattened_threadIdIn_group, compute_id,
-   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i32;",  "rn",
+   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i",  "rn",
   [
     dxil_param<0,  "i32",  "",  "result">,
     dxil_param<1,  "i32",  "opcode",  "DXIL opcode">
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index dbf0a7aeaea17..0a53f3156205a 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -278,30 +278,42 @@ static std::string emitDXILOperationFnAttr(StringRef FnAttr) {
       .Default("Attribute::None");
 }
 
-static std::string getOverloadKind(StringRef Overload) {
-  return StringSwitch<std::string>(Overload)
-      .Case("half", "OverloadKind::HALF")
-      .Case("float", "OverloadKind::FLOAT")
-      .Case("double", "OverloadKind::DOUBLE")
-      .Case("i1", "OverloadKind::I1")
-      .Case("i16", "OverloadKind::I16")
-      .Case("i32", "OverloadKind::I32")
-      .Case("i64", "OverloadKind::I64")
-      .Case("udt", "OverloadKind::UserDefineType")
-      .Case("obj", "OverloadKind::ObjectType")
-      .Default("OverloadKind::VOID");
+// Map a DXIL specification overload-type character to its OverloadKind name.
+static std::string getOverloadKind(char Overload) {
+  switch (Overload) {
+  case 'h':
+    return "OverloadKind::HALF";
+  case 'f':
+    return "OverloadKind::FLOAT";
+  case 'd':
+    return "OverloadKind::DOUBLE";
+  case '1':
+    return "OverloadKind::I1";
+  case 'w':
+    return "OverloadKind::I16";
+  case 'i':
+    return "OverloadKind::I32";
+  case 'l':
+    return "OverloadKind::I64";
+  case 'u':
+    return "OverloadKind::UserDefineType";
+  case 'o':
+    return "OverloadKind::ObjectType";
+  case 'v':
+    return "OverloadKind::VOID";
+  default:
+    llvm_unreachable("Unknown overload kind specified");
+  }
 }
 
 static std::string getDXILOperationOverload(StringRef Overloads) {
-  SmallVector<StringRef> OverloadStrs;
-  Overloads.split(OverloadStrs, ';', /*MaxSplit*/ -1, /*KeepEmpty*/ false);
   // Format is: OverloadKind::FLOAT | OverloadKind::HALF
-  assert(!OverloadStrs.empty() && "Invalid overloads");
-  auto It = OverloadStrs.begin();
+  assert(!Overloads.empty() && "Invalid overloads");
+  auto It = Overloads.begin();
   std::string Result;
   raw_string_ostream OS(Result);
   OS << getOverloadKind(*It);
-  for (++It; It != OverloadStrs.end(); ++It) {
+  for (++It; It != Overloads.end(); ++It) {
     OS << " | " << getOverloadKind(*It);
   }
   return OS.str();

>From 89d82144dbbf8082ffa00172ca90ebba4f1096c3 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 1 Feb 2024 15:36:53 -0500
Subject: [PATCH 3/4] Add DXIL Op description for LLVM intrinsic fabs. Add test
 to verify lowering of fabs.

---
 llvm/lib/Target/DirectX/DXIL.td | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index e6e79e6bbe1c5..c31ba11cb0253 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -122,7 +122,7 @@ let op_class = unary, ops = [
       dxil_param<0, "$o", "", "operation result">,
       dxil_param<1, "i32", "opcode", "DXIL opcode">,
       dxil_param<2, "$o", "value", "input value">
-    ], fn_attr = "rn" in
+    ], fn_attr = "rn" in {
 // Unary float DXIL Ops
   let category = unary_float, oload_types = "hf", stats_group = ["floats"] in {
     def Cos : dxil_op_min<"Cos", 12, "returns cosine(theta) for theta in radians.">,
@@ -131,6 +131,12 @@ let op_class = unary, ops = [
     def Sin : dxil_op_min<"Sin", 13, "returns sine(theta) for theta in radians.">,
               dxil_map_intrinsic<int_sin>;
   }
+  // oload_types are different from above
+  let category = unary_float, oload_types = "hfd", stats_group = ["floats"] in {
+    def FAbs : dxil_op_min<"FAbs", 6, "returns the absolute value of the input value.">,
+               dxil_map_intrinsic<int_fabs>;
+  }
+}
 
 // Binary class DXIL Ops
 let op_class = binary, ops = [

>From 51ee896ead097e899bacfe3e6f6fd4cebc1fceff Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 1 Feb 2024 15:57:47 -0500
Subject: [PATCH 4/4] Add DXIL Op description for LLVM intrinsic floor and
 ceil. Add test to verify lowering of floor and ceil.

---
 llvm/lib/Target/DirectX/DXIL.td    | 10 ++++++++--
 llvm/test/CodeGen/DirectX/ceil.ll  | 17 +++++++++++++++++
 llvm/test/CodeGen/DirectX/fabs.ll  | 25 +++++++++++++++++++++++++
 llvm/test/CodeGen/DirectX/floor.ll | 17 +++++++++++++++++
 4 files changed, 67 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/ceil.ll
 create mode 100644 llvm/test/CodeGen/DirectX/fabs.ll
 create mode 100644 llvm/test/CodeGen/DirectX/floor.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index c31ba11cb0253..83cf038d0767a 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -123,15 +123,21 @@ let op_class = unary, ops = [
       dxil_param<1, "i32", "opcode", "DXIL opcode">,
       dxil_param<2, "$o", "value", "input value">
     ], fn_attr = "rn" in {
-// Unary float DXIL Ops
+  // Unary half/float DXIL Ops
   let category = unary_float, oload_types = "hf", stats_group = ["floats"] in {
     def Cos : dxil_op_min<"Cos", 12, "returns cosine(theta) for theta in radians.">,
               dxil_map_intrinsic<int_cos>;
 
     def Sin : dxil_op_min<"Sin", 13, "returns sine(theta) for theta in radians.">,
               dxil_map_intrinsic<int_sin>;
+
+    def Round_ni : dxil_op_min<"Round_ni", 27, "floating-point round to integral float towards -INF, commonly known as floor().">,
+              dxil_map_intrinsic<int_floor>;
+
+    def Round_pi : dxil_op_min<"Round_pi", 28, "floating-point round to integral float towards +INF, commonly known as ceil().">,
+              dxil_map_intrinsic<int_ceil>;
   }
-  // oload_types are different from above
+  // Unary half/float/double DXIL Ops
   let category = unary_float, oload_types = "hfd", stats_group = ["floats"] in {
     def FAbs : dxil_op_min<"FAbs", 6, "returns the absolute value of the input value.">,
                dxil_map_intrinsic<int_fabs>;
diff --git a/llvm/test/CodeGen/DirectX/ceil.ll b/llvm/test/CodeGen/DirectX/ceil.ll
new file mode 100644
index 0000000000000..29cfa52e50215
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ceil.ll
@@ -0,0 +1,20 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for ceil are generated for half and float.
+
+define noundef half @test_ceil_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 28, half %{{.*}})
+  %0 = call half @llvm.ceil.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_ceil_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 28, float %{{.*}})
+  %0 = call float @llvm.ceil.f32(float %a)
+  ret float %0
+}
+
+declare half @llvm.ceil.f16(half)
+declare float @llvm.ceil.f32(float)
diff --git a/llvm/test/CodeGen/DirectX/fabs.ll b/llvm/test/CodeGen/DirectX/fabs.ll
new file mode 100644
index 0000000000000..25c3c6ed8a47d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fabs.ll
@@ -0,0 +1,29 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for fabs are generated
+; for half, float and double.
+
+define noundef half @test_fabs_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 6, half %{{.*}})
+  %0 = call half @llvm.fabs.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_fabs_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 6, float %{{.*}})
+  %0 = call float @llvm.fabs.f32(float %a)
+  ret float %0
+}
+
+define noundef double @test_fabs_double(double noundef %a) {
+entry:
+  ; CHECK: call double @dx.op.unary.f64(i32 6, double %{{.*}})
+  %0 = call double @llvm.fabs.f64(double %a)
+  ret double %0
+}
+
+declare half @llvm.fabs.f16(half)
+declare float @llvm.fabs.f32(float)
+declare double @llvm.fabs.f64(double)
diff --git a/llvm/test/CodeGen/DirectX/floor.ll b/llvm/test/CodeGen/DirectX/floor.ll
new file mode 100644
index 0000000000000..df69804513f8d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/floor.ll
@@ -0,0 +1,20 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for floor are generated for half and float.
+
+define noundef half @test_floor_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 27, half %{{.*}})
+  %0 = call half @llvm.floor.f16(half %a)
+  ret half %0
+}
+
+define noundef float @test_floor_float(float noundef %a) {
+entry:
+  ; CHECK: call float @dx.op.unary.f32(i32 27, float %{{.*}})
+  %0 = call float @llvm.floor.f32(float %a)
+  ret float %0
+}
+
+declare half @llvm.floor.f16(half)
+declare float @llvm.floor.f32(float)



More information about the llvm-commits mailing list