[llvm] [DirectX][NFC] Change specification of overload types and attribute in DXIL.td (PR #81184)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 11:57:14 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

<details>
<summary>Changes</summary>

 - Specify the overload types of a DXIL operation as a list of types instead of a string.
 - Add record definitions for the supported DXIL types to `DXIL.td`, leveraging `LLVMType` to avoid duplicate definitions.
 - Spell out the DXIL operation attribute strings (e.g. `"rn"` becomes `"NoReadMemory"`).
 - Make the corresponding changes in `DXILEmitter.cpp` to process the new records.

---
Full diff: https://github.com/llvm/llvm-project/pull/81184.diff


2 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXIL.td (+42-11) 
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+90-53) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 3f3ace5a1a3a36..4b09c9597e2228 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -34,12 +34,42 @@ def BinaryUintCategory : DXILOpCategory<"Binary uint">;
 def UnaryFloatCategory : DXILOpCategory<"Unary float">;
 def ComputeIDCategory : DXILOpCategory<"Compute/Mesh/Amplification shader">;
 
+// ValueTypes specific to DXIL
+// Define Overload value type as an entity with no size and an arbitrary value
+// of 1024 - assuming that all currently defined values are less than 1024
+def overloadVal       : ValueType<0, 1024>;
+def resourceRetVal    : ValueType<0, 1025>;
+def cbufferRetVal     : ValueType<0, 1026>;
+def handleVal         : ValueType<0, 1027>;
+
+// Following are the scalar types supported by DXIL operations and are synonymous
+// to llvm_*_ty defined for readability and ease of use in the context of this file.
+
+def voidTy  : LLVMType<isVoid>;
+
+// Floating point types
+def f16Ty   : LLVMType<f16>;
+def f32Ty   : LLVMType<f32>;
+def f64Ty   : LLVMType<f64>;
+
+// Integer types
+def i1Ty   : LLVMType<i1>;
+def i8Ty   : LLVMType<i8>;
+def i16Ty  : LLVMType<i16>;
+def i32Ty  : LLVMType<i32>;
+def i64Ty  : LLVMType<i64>;
+
+def overloadTy        : LLVMType<overloadVal>;
+def resourceRetTy     : LLVMType<resourceRetVal>;
+def cbufferRetTy      : LLVMType<cbufferRetVal>;
+def handleTy          : LLVMType<handleVal>;
+
 // The parameter description for a DXIL operation
 class DXILOpParameter<int pos, string type, string name, string doc,
                  bit isConstant = 0, string enumName = "",
                  int maxValue = 0> {
   int Pos = pos;               // Position in parameter list
-  string LLVMType = type;      // LLVM type name, $o for overload, $r for resource
+  string Type = type;          // LLVM type name, $o for overload, $r for resource
                                // type, $cb for legacy cbuffer, $u4 for u4 struct
   string Name = name;          // Short, unique parameter name
   string Doc = doc;            // Description of this parameter
@@ -56,9 +86,10 @@ class DXILOperationDesc {
   DXILOpCategory OpCategory;  // Category of the operation
   string Doc = "";            // Description of the operation
   list<DXILOpParameter> Params = []; // Parameter list of the operation
-  string OverloadTypes = "";  // Overload types, if applicable
-  string Attributes = "";     // Attribute shorthands: rn=does not access
-                              // memory,ro=only reads from memory,
+  list<LLVMType> OverloadTypes = [];  // Overload types, if applicable
+  string Attributes = "";     // Operation Attribute
+                              // "NoReadMemory" - does not read memory
+                              // "ReadMemory"   - reads memory
   bit IsDerivative = 0;       // Whether this is some kind of derivative
   bit IsGradient = 0;         // Whether this requires a gradient calculation
   bit IsFeedback = 0;         // Whether this is a sampler feedback operation
@@ -71,7 +102,7 @@ class DXILOperationDesc {
 }
 
 class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory opCategory, string doc,
-              string oloadTypes, string attrs, list<DXILOpParameter> params,
+              list<LLVMType> oloadTypes, string attrs, list<DXILOpParameter> params,
               list<string> statsGroup = []> : DXILOperationDesc {
   let OpName = name;
   let OpCode = opCode;
@@ -88,7 +119,7 @@ class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory
 class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
 
 def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine(theta) for theta in radians.",
-  "half;float;", "rn",
+  [f16Ty,f32Ty], "NoReadMemory",
   [
     DXILOpParameter<0, "$o", "", "operation result">,
     DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -98,7 +129,7 @@ def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine
   LLVMIntrinsic<int_sin>;
 
 def UMax : DXILOperation< "UMax", 39,  BinaryClass,  BinaryUintCategory, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
-    "i16;i32;i64;",  "rn",
+    [i16Ty,i32Ty,i64Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "$o",  "",  "operation result">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -108,7 +139,7 @@ def UMax : DXILOperation< "UMax", 39,  BinaryClass,  BinaryUintCategory, "unsign
   ["uints"]>,
   LLVMIntrinsic<int_umax>;
 
-def ThreadId : DXILOperation< "ThreadId", 93,  ThreadIdClass, ComputeIDCategory, "reads the thread ID", "i32;",  "rn",
+def ThreadId : DXILOperation< "ThreadId", 93,  ThreadIdClass, ComputeIDCategory, "reads the thread ID", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "thread ID component">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -116,7 +147,7 @@ def ThreadId : DXILOperation< "ThreadId", 93,  ThreadIdClass, ComputeIDCategory,
   ]>,
   LLVMIntrinsic<int_dx_thread_id>;
 
-def GroupId : DXILOperation< "GroupId", 94,  GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", "i32;",  "rn",
+def GroupId : DXILOperation< "GroupId", 94,  GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "group ID component">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -125,7 +156,7 @@ def GroupId : DXILOperation< "GroupId", 94,  GroupIdClass, ComputeIDCategory, "r
   LLVMIntrinsic<int_dx_group_id>;
 
 def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95,  ThreadIdInGroupClass, ComputeIDCategory,
-  "reads the thread ID within the group (SV_GroupThreadID)", "i32;",  "rn",
+  "reads the thread ID within the group (SV_GroupThreadID)", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "thread ID in group component">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">,
@@ -134,7 +165,7 @@ def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95,  ThreadIdInGroupClas
   LLVMIntrinsic<int_dx_thread_id_in_group>;
 
 def FlattenedThreadIdInGroup : DXILOperation< "FlattenedThreadIdInGroup", 96,  FlattenedThreadIdInGroupClass, ComputeIDCategory,
-   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i32;",  "rn",
+   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", [i32Ty],  "NoReadMemory",
   [
     DXILOpParameter<0,  "i32",  "",  "result">,
     DXILOpParameter<1,  "i32",  "opcode",  "DXIL opcode">
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index cb9f9c6b03c636..32c5f4ff16f004 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -49,7 +49,7 @@ struct DXILOperationDesc {
   StringRef Doc;       // the documentation description of this instruction
 
   SmallVector<DXILParameter> Params; // the operands that this instruction takes
-  StringRef OverloadTypes;       // overload types if applicable
+  SmallVector<ParameterKind> OverloadTypes; // overload types if applicable
   StringRef FnAttr;              // attribute shorthands: rn=does not access
                                  // memory,ro=only reads from memory
   StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
@@ -69,37 +69,31 @@ struct DXILOperationDesc {
   int OverloadParamIndex; // parameter index which control the overload.
                           // When < 0, should be only 1 overload type.
   SmallVector<StringRef, 4> counters; // counters for this inst.
-  DXILOperationDesc(const Record *R) {
-    OpName = R->getValueAsString("OpName");
-    OpCode = R->getValueAsInt("OpCode");
-    OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
-    Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
-
-    if (R->getValue("llvm_intrinsic")) {
-      auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
-      auto DefName = IntrinsicDef->getName();
-      assert(DefName.starts_with("int_") && "invalid intrinsic name");
-      // Remove the int_ from intrinsic name.
-      Intrinsic = DefName.substr(4);
-    }
-
-    Doc = R->getValueAsString("Doc");
-
-    ListInit *ParamList = R->getValueAsListInit("Params");
-    OverloadParamIndex = -1;
-    for (unsigned I = 0; I < ParamList->size(); ++I) {
-      Record *Param = ParamList->getElementAsRecord(I);
-      Params.emplace_back(DXILParameter(Param));
-      auto &CurParam = Params.back();
-      if (CurParam.Kind >= ParameterKind::OVERLOAD)
-        OverloadParamIndex = I;
-    }
-    OverloadTypes = R->getValueAsString("OverloadTypes");
-    FnAttr = R->getValueAsString("Attributes");
-  }
+  DXILOperationDesc(const Record *);
 };
 } // end anonymous namespace
 
+/// Map the name of a DXIL type record defined in DXIL.td (e.g. "f32Ty") to
+/// its dxil::ParameterKind as defined in llvm/Support/DXILABI.h. Names that
+/// are not recognized map to ParameterKind::INVALID.
+static ParameterKind getDXILTypeNameToKind(StringRef TypeName) {
+  return StringSwitch<ParameterKind>(TypeName)
+      .Case("voidTy", ParameterKind::VOID)
+      .Case("f16Ty", ParameterKind::HALF)
+      .Case("f32Ty", ParameterKind::FLOAT)
+      .Case("f64Ty", ParameterKind::DOUBLE)
+      .Case("i1Ty", ParameterKind::I1)
+      .Case("i8Ty", ParameterKind::I8)
+      .Case("i16Ty", ParameterKind::I16)
+      .Case("i32Ty", ParameterKind::I32)
+      .Case("i64Ty", ParameterKind::I64)
+      .Case("overloadTy", ParameterKind::OVERLOAD)
+      .Case("handleTy", ParameterKind::DXIL_HANDLE)
+      .Case("cbufferRetTy", ParameterKind::CBUFFER_RET)
+      .Case("resourceRetTy", ParameterKind::RESOURCE_RET)
+      .Default(ParameterKind::INVALID);
+}
+
 static ParameterKind parameterTypeNameToKind(StringRef Name) {
   return StringSwitch<ParameterKind>(Name)
       .Case("void", ParameterKind::VOID)
@@ -118,10 +112,44 @@ static ParameterKind parameterTypeNameToKind(StringRef Name) {
       .Default(ParameterKind::INVALID);
 }
 
+// Populate this description from the TableGen record of a DXIL operation.
+DXILOperationDesc::DXILOperationDesc(const Record *R) {
+  OpName = R->getValueAsString("OpName");
+  OpCode = R->getValueAsInt("OpCode");
+  OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
+  Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
+
+  // Optional LLVMIntrinsic mapping; the "int_" def-name prefix is stripped.
+  if (R->getValue("llvm_intrinsic")) {
+    auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
+    auto DefName = IntrinsicDef->getName();
+    assert(DefName.starts_with("int_") && "invalid intrinsic name");
+    Intrinsic = DefName.substr(4);
+  }
+  Doc = R->getValueAsString("Doc");
+
+  // OverloadParamIndex: last parameter whose kind is >= OVERLOAD, else -1.
+  ListInit *ParamList = R->getValueAsListInit("Params");
+  OverloadParamIndex = -1;
+  for (unsigned I = 0; I < ParamList->size(); ++I) {
+    Params.emplace_back(DXILParameter(ParamList->getElementAsRecord(I)));
+    if (Params.back().Kind >= ParameterKind::OVERLOAD)
+      OverloadParamIndex = I;
+  }
+
+  // Map each OverloadTypes entry (a type def such as f32Ty) to a ParameterKind
+  // by its record name. TypeRec is named so it does not shadow parameter R.
+  for (const Record *TypeRec : R->getValueAsListOfDefs("OverloadTypes"))
+    OverloadTypes.emplace_back(
+        getDXILTypeNameToKind(TypeRec->getNameInitAsString()));
+
+  FnAttr = R->getValueAsString("Attributes");
+}
+
 DXILParameter::DXILParameter(const Record *R) {
   Name = R->getValueAsString("Name");
   Pos = R->getValueAsInt("Pos");
-  Kind = parameterTypeNameToKind(R->getValueAsString("LLVMType"));
+  Kind = parameterTypeNameToKind(R->getValueAsString("Type"));
   if (R->getValue("Doc"))
     Doc = R->getValueAsString("Doc");
   IsConst = R->getValueAsBit("IsConstant");
@@ -268,36 +296,45 @@ static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
 
 static std::string emitDXILOperationFnAttr(StringRef FnAttr) {
   return StringSwitch<std::string>(FnAttr)
-      .Case("rn", "Attribute::ReadNone")
-      .Case("ro", "Attribute::ReadOnly")
+      .Case("NoReadMemory", "Attribute::ReadNone")
+      .Case("ReadMemory", "Attribute::ReadOnly")
       .Default("Attribute::None");
 }
 
-static std::string getOverloadKind(StringRef Overload) {
-  return StringSwitch<std::string>(Overload)
-      .Case("half", "OverloadKind::HALF")
-      .Case("float", "OverloadKind::FLOAT")
-      .Case("double", "OverloadKind::DOUBLE")
-      .Case("i1", "OverloadKind::I1")
-      .Case("i16", "OverloadKind::I16")
-      .Case("i32", "OverloadKind::I32")
-      .Case("i64", "OverloadKind::I64")
-      .Case("udt", "OverloadKind::UserDefineType")
-      .Case("obj", "OverloadKind::ObjectType")
-      .Default("OverloadKind::VOID");
+static std::string overloadKindStr(ParameterKind Overload) {
+  switch (Overload) {
+  case ParameterKind::HALF:
+    return "OverloadKind::HALF";
+  case ParameterKind::FLOAT:
+    return "OverloadKind::FLOAT";
+  case ParameterKind::DOUBLE:
+    return "OverloadKind::DOUBLE";
+  case ParameterKind::I1:
+    return "OverloadKind::I1";
+  case ParameterKind::I8:
+    return "OverloadKind::I8";
+  case ParameterKind::I16:
+    return "OverloadKind::I16";
+  case ParameterKind::I32:
+    return "OverloadKind::I32";
+  case ParameterKind::I64:
+    return "OverloadKind::I64";
+  case ParameterKind::VOID:
+    return "OverloadKind::VOID";
+  default:
+    return "OverloadKind::UNKNOWN";
+  }
 }
 
-static std::string getDXILOperationOverload(StringRef Overloads) {
-  SmallVector<StringRef> OverloadStrs;
-  Overloads.split(OverloadStrs, ';', /*MaxSplit*/ -1, /*KeepEmpty*/ false);
+static std::string
+getDXILOperationOverloads(const SmallVector<ParameterKind> &Overloads) {
   // Format is: OverloadKind::FLOAT | OverloadKind::HALF
-  assert(!OverloadStrs.empty() && "Invalid overloads");
-  auto It = OverloadStrs.begin();
+  assert(!Overloads.empty() && "Invalid overloads");
   std::string Result;
   raw_string_ostream OS(Result);
-  OS << getOverloadKind(*It);
-  for (++It; It != OverloadStrs.end(); ++It) {
-    OS << " | " << getOverloadKind(*It);
+  for (auto It = Overloads.begin(); It != Overloads.end(); ++It) {
+    // " | " separates kinds; an empty list never dereferences end().
+    OS << (It == Overloads.begin() ? "" : " | ") << overloadKindStr(*It);
   }
   return OS.str();
 }
@@ -367,7 +404,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     OS << "  { dxil::OpCode::" << Op.OpName << ", "
        << OpStrings.get(Op.OpName.str()) << ", OpCodeClass::" << Op.OpClass
        << ", " << OpClassStrings.get(getDXILOpClassName(Op.OpClass)) << ", "
-       << getDXILOperationOverload(Op.OverloadTypes) << ", "
+       << getDXILOperationOverloads(Op.OverloadTypes) << ", "
        << emitDXILOperationFnAttr(Op.FnAttr) << ", " << Op.OverloadParamIndex
        << ", " << Op.Params.size() << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";

``````````

</details>


https://github.com/llvm/llvm-project/pull/81184


More information about the llvm-commits mailing list