[llvm] [DirectX][DXIL] Align type spec of TableGen DXIL Op and LLVM Intrinsic (PR #86311)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 27 14:24:59 PDT 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/86311

>From 972472813d0ad73f050bafd8763e66bb3516f331 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 22 Mar 2024 13:11:49 -0400
Subject: [PATCH] Align specification of return type and parameter type fields
 of DXIL Op mapping with those of TableGen class Intrinsic.

A void return type of LLVM Intrinsic is represented as [] in its
TableGen description record. Currently, a void return type of
DXIL Operation is represented as [llvm_void_ty]. In addition,
return and parameter types are recorded as a single list with
an understanding that element at index `0` is the return type.

These changes leverage and align DXIL Op type specification with
the type specification of the LLVM Intrinsic. As a result, return
and parameter types are now specified as two separate lists no
longer requiring a different representation for void return type.
Additionally, type specification would be more succinct yet
equally informative for DXIL Op records for which the same LLVM
Intrinsics types are also valid.

Added a test to verify lowering of LLVM intrinsic with void return.

Barrier intrinsic has a void return type. Specification of its
DXIL Op can inherit the types of this intrinsic. The test verifies
the changes.

Move OverloadKind to DXILABI.h.

Update enumerator names of enum OverloadKind to follow naming conventions.
---
 llvm/include/llvm/IR/IntrinsicsDirectX.td |   1 +
 llvm/include/llvm/Support/DXILABI.h       |  15 ++
 llvm/lib/Target/DirectX/DXIL.td           |  65 +++--
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp |  55 ++--
 llvm/test/CodeGen/DirectX/barrier.ll      |  11 +
 llvm/utils/TableGen/DXILEmitter.cpp       | 289 +++++++++++++---------
 6 files changed, 260 insertions(+), 176 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/barrier.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index a871fac46b9fd5..ce8c25233ed566 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -16,6 +16,7 @@ def int_dx_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrW
 def int_dx_group_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
 def int_dx_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>;
 def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMem, IntrWillReturn]>;
+def int_dx_barrier  : Intrinsic<[], [llvm_i32_ty], [IntrNoDuplicate, IntrWillReturn]>;
 
 def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">,
     Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>;
diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index da4bea8fc46e3a..fdf140d125f879 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -39,6 +39,21 @@ enum class ParameterKind : uint8_t {
   DXILHandle,
 };
 
+enum OverloadKind : uint16_t {
+  Invalid = 0,
+  Void = 1,
+  Half = 1 << 1,
+  Float = 1 << 2,
+  Double = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+
 /// The kind of resource for an SRV or UAV resource. Sometimes referred to as
 /// "Shape" in the DXIL docs.
 enum class ResourceKind : uint32_t {
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 1fd6f3ed044ecd..fa4cf446554cc3 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -240,18 +240,23 @@ class DXILOpMappingBase {
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
   string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
+  // The following fields denote the same semantics as those of Intrinsic class
+  // and are initialized with the same values as those of LLVMIntrinsic unless
+  // overridden in the definition of a record.
+  list<LLVMType> OpRetTypes = ?;    // Valid return types of DXIL Operation
+  list<LLVMType> OpParamTypes = ?;     // Valid parameter types of DXIL Operation
 }
 
 class DXILOpMapping<int opCode, DXILOpClass opClass,
                     Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
+                    list<LLVMType> retTys = [],
+                    list<LLVMType> paramTys = []> : DXILOpMappingBase {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
   DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
+  list<LLVMType> OpRetTypes = !if(!eq(!size(retTys), 0), LLVMIntrinsic.RetTypes, retTys);
+  list<LLVMType> OpParamTypes = !if(!eq(!size(paramTys), 0), LLVMIntrinsic.ParamTypes, paramTys);
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
@@ -259,39 +264,39 @@ def Abs : DXILOpMapping<6, unary, int_fabs,
                          "Returns the absolute value of the input.">;
 def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
                          "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
+                         [llvm_i1_ty], [llvm_halforfloat_ty]>;
 def Cos  : DXILOpMapping<12, unary, int_cos,
                          "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Exp2 : DXILOpMapping<21, unary, int_exp2,
                          "Returns the base 2 exponential, or 2**x, of the specified value."
                          "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Log2 : DXILOpMapping<23, unary, int_log2,
                          "Returns the base-2 logarithm of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Sqrt : DXILOpMapping<24, unary, int_sqrt,
                          "Returns the square root of the specified floating-point"
                          "value, per component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
                          "Returns the reciprocal of the square root of the specified value."
                          "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
                          "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def Floor : DXILOpMapping<27, unary, int_floor,
                          "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
+                         [llvm_halforfloat_ty], [LLVMMatchType<0>]>;
 def FMax : DXILOpMapping<35, binary, int_maxnum,
                          "Float maximum. FMax(a,b) = a > b ? a : b">;
 def FMin : DXILOpMapping<36, binary, int_minnum,
@@ -305,20 +310,28 @@ def UMax : DXILOpMapping<39, binary, int_umax,
 def UMin : DXILOpMapping<40, binary, int_umin,
                          "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
 def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
-                         "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
+                         "Floating point arithmetic multiply/add operation. "
+                         "fmad(m,a,b) = m * a + b.">;
 def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
-                         "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
+                         "Signed integer arithmetic multiply/add operation. "
+                         "imad(m,a,b) = m * a + b.">;
 def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
-                         "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
-  def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
-  def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
-  def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
+                         "Unsigned integer arithmetic multiply/add operation. "
+                         "umad(m,a,b) = m * a + b.">;
+def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0]"
+                         " + ... + a[n]*b[n] where n is between 0 and 1",
+                         [llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)>;
+def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0]"
+                         " + ... + a[n]*b[n] where n is between 0 and 2",
+                         [llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)>;
+def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
+                         "dot product of two float vectors Dot(a,b) = a[0]*b[0]"
+                         " + ... + a[n]*b[n] where n is between 0 and 3",
+                         [llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)>;
+def Barrier : DXILOpMapping<80, barrier, int_dx_barrier,
+                          "Inserts a memory barrier in the shader">;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
                              "Reads the thread ID">;
 def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 0b3982ea0f438a..22bbda461f1ae8 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -21,31 +21,13 @@ using namespace llvm::dxil;
 
 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
 
-namespace {
-
-enum OverloadKind : uint16_t {
-  VOID = 1,
-  HALF = 1 << 1,
-  FLOAT = 1 << 2,
-  DOUBLE = 1 << 3,
-  I1 = 1 << 4,
-  I8 = 1 << 5,
-  I16 = 1 << 6,
-  I32 = 1 << 7,
-  I64 = 1 << 8,
-  UserDefineType = 1 << 9,
-  ObjectType = 1 << 10,
-};
-
-} // namespace
-
 static const char *getOverloadTypeName(OverloadKind Kind) {
   switch (Kind) {
-  case OverloadKind::HALF:
+  case OverloadKind::Half:
     return "f16";
-  case OverloadKind::FLOAT:
+  case OverloadKind::Float:
     return "f32";
-  case OverloadKind::DOUBLE:
+  case OverloadKind::Double:
     return "f64";
   case OverloadKind::I1:
     return "i1";
@@ -57,12 +39,15 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
     return "i32";
   case OverloadKind::I64:
     return "i64";
-  case OverloadKind::VOID:
+  case OverloadKind::Void:
   case OverloadKind::ObjectType:
   case OverloadKind::UserDefineType:
     break;
+  case OverloadKind::Invalid:
+    report_fatal_error("Invalid Overload Type for type name lookup",
+                       /* gen_crash_diag=*/false);
   }
-  llvm_unreachable("invalid overload type for name");
+  llvm_unreachable("Unhandled Overload Type specified for type name lookup");
   return "void";
 }
 
@@ -70,13 +55,13 @@ static OverloadKind getOverloadKind(Type *Ty) {
   Type::TypeID T = Ty->getTypeID();
   switch (T) {
   case Type::VoidTyID:
-    return OverloadKind::VOID;
+    return OverloadKind::Void;
   case Type::HalfTyID:
-    return OverloadKind::HALF;
+    return OverloadKind::Half;
   case Type::FloatTyID:
-    return OverloadKind::FLOAT;
+    return OverloadKind::Float;
   case Type::DoubleTyID:
-    return OverloadKind::DOUBLE;
+    return OverloadKind::Double;
   case Type::IntegerTyID: {
     IntegerType *ITy = cast<IntegerType>(Ty);
     unsigned Bits = ITy->getBitWidth();
@@ -93,7 +78,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
       return OverloadKind::I64;
     default:
       llvm_unreachable("invalid overload type");
-      return OverloadKind::VOID;
+      return OverloadKind::Void;
     }
   }
   case Type::PointerTyID:
@@ -102,7 +87,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
     return OverloadKind::ObjectType;
   default:
     llvm_unreachable("invalid overload type");
-    return OverloadKind::VOID;
+    return OverloadKind::Void;
   }
 }
 
@@ -147,7 +132,7 @@ struct OpCodeProperty {
 
 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
                                          const OpCodeProperty &Prop) {
-  if (Kind == OverloadKind::VOID) {
+  if (Kind == OverloadKind::Void) {
     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
   }
   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
@@ -157,7 +142,7 @@ static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
 
 static std::string constructOverloadTypeName(OverloadKind Kind,
                                              StringRef TypeName) {
-  if (Kind == OverloadKind::VOID)
+  if (Kind == OverloadKind::Void)
     return TypeName.str();
 
   assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
@@ -284,13 +269,13 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
   if (Prop->OverloadParamIndex < 0) {
     auto &Ctx = FT->getContext();
     switch (Prop->OverloadTys) {
-    case OverloadKind::VOID:
+    case OverloadKind::Void:
       return Type::getVoidTy(Ctx);
-    case OverloadKind::HALF:
+    case OverloadKind::Half:
       return Type::getHalfTy(Ctx);
-    case OverloadKind::FLOAT:
+    case OverloadKind::Float:
       return Type::getFloatTy(Ctx);
-    case OverloadKind::DOUBLE:
+    case OverloadKind::Double:
       return Type::getDoubleTy(Ctx);
     case OverloadKind::I1:
       return Type::getInt1Ty(Ctx);
diff --git a/llvm/test/CodeGen/DirectX/barrier.ll b/llvm/test/CodeGen/DirectX/barrier.ll
new file mode 100644
index 00000000000000..8be4aac1f782b5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/barrier.ll
@@ -0,0 +1,11 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Argument of llvm.dx.barrier is expected to be a mask of 
+; DXIL::BarrierMode values. Chose an int value for testing.
+
+define void @test_barrier() #0 {
+entry:
+  ; CHECK: call void @dx.op.barrier.i32(i32 80, i32 9)
+  call void @llvm.dx.barrier(i32 noundef 9)
+  ret void
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index f2504775d557f2..b3772b92f2c235 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -39,8 +39,8 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  SmallVector<Record *> OpTypes; // Vector of operand type records -
-                                 // return type is at index 0
+  SmallVector<OverloadKind> OpOverloadTys; // Vector of operand overload types -
+                                           // return type is at index 0
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -65,41 +65,167 @@ struct DXILOperationDesc {
 };
 } // end anonymous namespace
 
-/// Return dxil::ParameterKind corresponding to input LLVMType record
+/// Return dxil::ParameterKind corresponding to input Overload Kind
 ///
-/// \param R TableGen def record of class LLVMType
+/// \param OLKind Overload Kind
 /// \return ParameterKind As defined in llvm/Support/DXILABI.h
 
-static ParameterKind getParameterKind(const Record *R) {
+static ParameterKind getParameterKind(const dxil::OverloadKind OLKind) {
+  switch (OLKind) {
+  case OverloadKind::Void:
+    return ParameterKind::Void;
+  case OverloadKind::Half:
+    return ParameterKind::Half;
+  case OverloadKind::Float:
+    return ParameterKind::Float;
+  case OverloadKind::Double:
+    return ParameterKind::Double;
+  case OverloadKind::I1:
+    return ParameterKind::I1;
+  case OverloadKind::I8:
+    return ParameterKind::I8;
+  case OverloadKind::I16:
+    return ParameterKind::I16;
+  case OverloadKind::I32:
+    return ParameterKind::I32;
+  case OverloadKind::I64:
+    return ParameterKind::I64;
+  default:
+    if ((OLKind ==
+         (OverloadKind::Half | OverloadKind::Float | OverloadKind::Double)) ||
+        (OLKind == (OverloadKind::Half | OverloadKind::Float)) ||
+        (OLKind == (OverloadKind::I1 | OverloadKind::I8 | OverloadKind::I16 |
+                    OverloadKind::I32 | OverloadKind::I64)) ||
+        (OLKind == (OverloadKind::I16 | OverloadKind::I32))) {
+      return ParameterKind::Overload;
+    } else {
+      report_fatal_error("Unsupported Overload Type encountered",
+                         /* gen_crash_diag=*/false);
+    }
+  }
+}
+
+/// Return a string representation of ParameterKind enum
+/// \param Kind Parameter Kind enum value
+/// \return std::string string representation of input Kind
+static std::string getParameterKindStr(ParameterKind Kind) {
+  switch (Kind) {
+  case ParameterKind::Invalid:
+    return "Invalid";
+  case ParameterKind::Void:
+    return "Void";
+  case ParameterKind::Half:
+    return "Half";
+  case ParameterKind::Float:
+    return "Float";
+  case ParameterKind::Double:
+    return "Double";
+  case ParameterKind::I1:
+    return "I1";
+  case ParameterKind::I8:
+    return "I8";
+  case ParameterKind::I16:
+    return "I16";
+  case ParameterKind::I32:
+    return "I32";
+  case ParameterKind::I64:
+    return "I64";
+  case ParameterKind::Overload:
+    return "Overload";
+  case ParameterKind::CBufferRet:
+    return "CBufferRet";
+  case ParameterKind::ResourceRet:
+    return "ResourceRet";
+  case ParameterKind::DXILHandle:
+    return "DXILHandle";
+  }
+  llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
+}
+
+static dxil::OverloadKind getOverloadKind(const Record *R) {
   auto VTRec = R->getValueAsDef("VT");
   switch (getValueType(VTRec)) {
   case MVT::isVoid:
-    return ParameterKind::Void;
+    return OverloadKind::Void;
   case MVT::f16:
-    return ParameterKind::Half;
+    return OverloadKind::Half;
   case MVT::f32:
-    return ParameterKind::Float;
+    return OverloadKind::Float;
   case MVT::f64:
-    return ParameterKind::Double;
+    return OverloadKind::Double;
   case MVT::i1:
-    return ParameterKind::I1;
+    return OverloadKind::I1;
   case MVT::i8:
-    return ParameterKind::I8;
+    return OverloadKind::I8;
   case MVT::i16:
-    return ParameterKind::I16;
+    return OverloadKind::I16;
   case MVT::i32:
-    return ParameterKind::I32;
-  case MVT::fAny:
+    return OverloadKind::I32;
+  case MVT::i64:
+    return OverloadKind::I64;
   case MVT::iAny:
-    return ParameterKind::Overload;
+    return static_cast<dxil::OverloadKind>(
+        OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64);
+  case MVT::fAny:
+    return static_cast<dxil::OverloadKind>(
+        OverloadKind::Half | OverloadKind::Float | OverloadKind::Double);
   case MVT::Other:
     // Handle DXIL-specific overload types
-    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
-      return ParameterKind::Overload;
+    {
+      if (R->getValueAsInt("isHalfOrFloat")) {
+        return static_cast<dxil::OverloadKind>(OverloadKind::Half |
+                                               OverloadKind::Float);
+      } else if (R->getValueAsInt("isI16OrI32")) {
+        return static_cast<dxil::OverloadKind>(OverloadKind::I16 |
+                                               OverloadKind::I32);
+      }
     }
     LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable("Support for specified DXIL Type not yet implemented");
+    report_fatal_error(
+        "Support for specified parameter OverloadKind not yet implemented",
+        /* gen_crash_diag=*/false);
+  }
+}
+
+/// Return a string representation of OverloadKind enum
+/// \param OLKind Overload Kind
+/// \return std::string string representation of OverloadKind
+
+static std::string getOverloadKindStr(const dxil::OverloadKind OLKind) {
+  switch (OLKind) {
+  case OverloadKind::Void:
+    return "OverloadKind::Void";
+  case OverloadKind::Half:
+    return "OverloadKind::Half";
+  case OverloadKind::Float:
+    return "OverloadKind::Float";
+  case OverloadKind::Double:
+    return "OverloadKind::Double";
+  case OverloadKind::I1:
+    return "OverloadKind::I1";
+  case OverloadKind::I8:
+    return "OverloadKind::I8";
+  case OverloadKind::I16:
+    return "OverloadKind::I16";
+  case OverloadKind::I32:
+    return "OverloadKind::I32";
+  case OverloadKind::I64:
+    return "OverloadKind::I64";
+  default:
+    if (OLKind == (OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64)) {
+      return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
+    } else if (OLKind == (OverloadKind::Half | OverloadKind::Float |
+                          OverloadKind::Double)) {
+      return "OverloadKind::Half | OverloadKind::Float | OverloadKind::Double";
+    } else if (OLKind == (OverloadKind::Half | OverloadKind::Float)) {
+      return "OverloadKind::Half | OverloadKind::Float";
+    } else if (OLKind == (OverloadKind::I16 | OverloadKind::I32)) {
+      return "OverloadKind::I16 | OverloadKind::I32";
+    } else {
+      report_fatal_error("Unsupported OverloadKind specified",
+                         /* gen_crash_diag=*/false);
+    }
   }
 }
 
@@ -114,9 +240,25 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
 
   Doc = R->getValueAsString("Doc");
 
-  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
+  // Populate OpOverloadTys with return type and parameter types
+  auto RetTypeRecs = R->getValueAsListOfDefs("OpRetTypes");
+  auto ParamTypeRecs = R->getValueAsListOfDefs("OpParamTypes");
+  unsigned RetTypeRecSize = RetTypeRecs.size();
+  unsigned ParamTypeRecSize = ParamTypeRecs.size();
+  // A vector with return type and parameter type records
+  std::vector<Record *> TypeRecs;
+  TypeRecs.reserve(RetTypeRecSize + ParamTypeRecSize);
+  // If return type list is empty, the return type is void
+  if (RetTypeRecSize == 0) {
+    OpOverloadTys.emplace_back(OverloadKind::Void);
+  } else {
+    // Append RetTypeRecs to TypeRecs
+    TypeRecs.insert(TypeRecs.end(), RetTypeRecs.begin(), RetTypeRecs.end());
+  }
+  // Append ParamTypeRecs to TypeRecs
+  TypeRecs.insert(TypeRecs.end(), ParamTypeRecs.begin(), ParamTypeRecs.end());
+
   unsigned TypeRecsSize = TypeRecs.size();
-  // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
   // This vector contains overload parameters in the order used to
@@ -146,13 +288,13 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
         if (!knownType) {
           report_fatal_error("Specification of multiple differing overload "
                              "parameter types not yet supported",
-                             false);
+                             /* gen_crash_diag=*/false);
         }
       } else {
         OverloadParamIndices.push_back(i);
       }
     }
-    // Populate OpTypes array according to the type specification
+    // Populate OpOverloadTys array according to the type specification
     if (TR->isAnonymous()) {
       // Check prior overload types exist
       assert(!OverloadParamIndices.empty() &&
@@ -160,10 +302,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
       // Get the parameter index of anonymous type, TR, references
       auto OLParamIndex = TR->getValueAsInt("Number");
       // Resolve and insert the type to that at OLParamIndex
-      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+      OpOverloadTys.emplace_back(getOverloadKind(TypeRecs[OLParamIndex]));
     } else {
-      // A non-anonymous type. Just record it in OpTypes
-      OpTypes.emplace_back(TR);
+      // A non-anonymous type. Just record it in OpOverloadTys
+      OpOverloadTys.emplace_back(getOverloadKind(TR));
     }
   }
 
@@ -172,7 +314,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   if (!OverloadParamIndices.empty()) {
     if (OverloadParamIndices.size() > 1)
       report_fatal_error("Multiple overload type specification not supported",
-                         false);
+                         /* gen_crash_diag=*/false);
     OverloadParamIndex = OverloadParamIndices[0];
   }
   // Get the operation class
@@ -196,89 +338,6 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   }
 }
 
-/// Return a string representation of ParameterKind enum
-/// \param Kind Parameter Kind enum value
-/// \return std::string string representation of input Kind
-static std::string getParameterKindStr(ParameterKind Kind) {
-  switch (Kind) {
-  case ParameterKind::Invalid:
-    return "Invalid";
-  case ParameterKind::Void:
-    return "Void";
-  case ParameterKind::Half:
-    return "Half";
-  case ParameterKind::Float:
-    return "Float";
-  case ParameterKind::Double:
-    return "Double";
-  case ParameterKind::I1:
-    return "I1";
-  case ParameterKind::I8:
-    return "I8";
-  case ParameterKind::I16:
-    return "I16";
-  case ParameterKind::I32:
-    return "I32";
-  case ParameterKind::I64:
-    return "I64";
-  case ParameterKind::Overload:
-    return "Overload";
-  case ParameterKind::CBufferRet:
-    return "CBufferRet";
-  case ParameterKind::ResourceRet:
-    return "ResourceRet";
-  case ParameterKind::DXILHandle:
-    return "DXILHandle";
-  }
-  llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
-}
-
-/// Return a string representation of OverloadKind enum that maps to
-/// input LLVMType record
-/// \param R TableGen def record of class LLVMType
-/// \return std::string string representation of OverloadKind
-
-static std::string getOverloadKindStr(const Record *R) {
-  auto VTRec = R->getValueAsDef("VT");
-  switch (getValueType(VTRec)) {
-  case MVT::isVoid:
-    return "OverloadKind::VOID";
-  case MVT::f16:
-    return "OverloadKind::HALF";
-  case MVT::f32:
-    return "OverloadKind::FLOAT";
-  case MVT::f64:
-    return "OverloadKind::DOUBLE";
-  case MVT::i1:
-    return "OverloadKind::I1";
-  case MVT::i8:
-    return "OverloadKind::I8";
-  case MVT::i16:
-    return "OverloadKind::I16";
-  case MVT::i32:
-    return "OverloadKind::I32";
-  case MVT::i64:
-    return "OverloadKind::I64";
-  case MVT::iAny:
-    return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
-  case MVT::fAny:
-    return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
-  case MVT::Other:
-    // Handle DXIL-specific overload types
-    {
-      if (R->getValueAsInt("isHalfOrFloat")) {
-        return "OverloadKind::HALF | OverloadKind::FLOAT";
-      } else if (R->getValueAsInt("isI16OrI32")) {
-        return "OverloadKind::I16 | OverloadKind::I32";
-      }
-    }
-    LLVM_FALLTHROUGH;
-  default:
-    llvm_unreachable(
-        "Support for specified parameter OverloadKind not yet implemented");
-  }
-}
-
 /// Emit Enums of DXIL Ops
 /// \param A vector of DXIL Ops
 /// \param Output stream
@@ -376,8 +435,8 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     OpClassStrings.add(Op.OpClass.data());
     SmallVector<ParameterKind> ParamKindVec;
     // ParamKindVec is a vector of parameters. Skip return type at index 0
-    for (unsigned i = 1; i < Op.OpTypes.size(); i++) {
-      ParamKindVec.emplace_back(getParameterKind(Op.OpTypes[i]));
+    for (unsigned i = 1; i < Op.OpOverloadTys.size(); i++) {
+      ParamKindVec.emplace_back(getParameterKind(Op.OpOverloadTys[i]));
     }
     ParameterMap[Op.OpClass] = ParamKindVec;
     Parameters.add(ParamKindVec);
@@ -391,7 +450,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
   // Emit the DXIL operation table.
   //{dxil::OpCode::Sin, OpCodeNameIndex, OpCodeClass::unary,
   // OpCodeClassNameIndex,
-  // OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
+  // OverloadKind::Float | OverloadKind::Half, Attribute::AttrKind::ReadNone, 0,
   // 3, ParameterTableOffset},
   OS << "static const OpCodeProperty *getOpCodeProperty(dxil::OpCode Op) "
         "{\n";
@@ -406,14 +465,14 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     // return type - as overload parameter to emit the appropriate overload kind
     // enum.
     if (OLParamIdx < 0) {
-      OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
+      OLParamIdx = (Op.OpOverloadTys.size() > 1) ? 1 : 0;
     }
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
-       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
+       << getOverloadKindStr(Op.OpOverloadTys[OLParamIdx]) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
-       << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
+       << Op.OverloadParamIndex << ", " << Op.OpOverloadTys.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
   }
   OS << "  };\n";



More information about the llvm-commits mailing list