[llvm] [DirectX][NFC] Change specification of overload types and attribute in DXIL.td (PR #81184)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 8 11:57:14 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: S. Bharadwaj Yadavalli (bharadwajy)
<details>
<summary>Changes</summary>
- Specify the overload types of a DXIL Operation as a list of types instead of a string.
- Add supported DXIL type record definitions to `DXIL.td`, leveraging `LLVMType` to avoid duplicate definitions.
- Spell out the DXIL Operation attribute specification string (e.g. `NoReadMemory` instead of `rn`).
- Make the corresponding changes to process these records in `DXILEmitter.cpp`.
---
Full diff: https://github.com/llvm/llvm-project/pull/81184.diff
2 Files Affected:
- (modified) llvm/lib/Target/DirectX/DXIL.td (+42-11)
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+90-53)
``````````diff
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 3f3ace5a1a3a36..4b09c9597e2228 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -34,12 +34,42 @@ def BinaryUintCategory : DXILOpCategory<"Binary uint">;
def UnaryFloatCategory : DXILOpCategory<"Unary float">;
def ComputeIDCategory : DXILOpCategory<"Compute/Mesh/Amplification shader">;
+// ValueTypes specific to DXIL
+// Define Overload value type as an entity with no size and an arbitrary value
+// of 1024 - assuming that all currently defined values are less than 1024
+def overloadVal : ValueType<0, 1024>;
+def resourceRetVal : ValueType<0, 1025>;
+def cbufferRetVal : ValueType<0, 1026>;
+def handleVal : ValueType<0, 1027>;
+
+// Following are the scalar types supported by DXIL operations and are synonymous
+// to llvm_*_ty defined for readability and ease of use in the context of this file.
+
+def voidTy : LLVMType<isVoid>;
+
+// Floating point types
+def f16Ty : LLVMType<f16>;
+def f32Ty : LLVMType<f32>;
+def f64Ty : LLVMType<f64>;
+
+// Integer types
+def i1Ty : LLVMType<i1>;
+def i8Ty : LLVMType<i8>;
+def i16Ty : LLVMType<i16>;
+def i32Ty : LLVMType<i32>;
+def i64Ty : LLVMType<i64>;
+
+def overloadTy : LLVMType<overloadVal>;
+def resourceRetTy : LLVMType<resourceRetVal>;
+def cbufferRetTy : LLVMType<cbufferRetVal>;
+def handleTy : LLVMType<handleVal>;
+
// The parameter description for a DXIL operation
class DXILOpParameter<int pos, string type, string name, string doc,
bit isConstant = 0, string enumName = "",
int maxValue = 0> {
int Pos = pos; // Position in parameter list
- string LLVMType = type; // LLVM type name, $o for overload, $r for resource
+ string Type = type; // LLVM type name, $o for overload, $r for resource
// type, $cb for legacy cbuffer, $u4 for u4 struct
string Name = name; // Short, unique parameter name
string Doc = doc; // Description of this parameter
@@ -56,9 +86,10 @@ class DXILOperationDesc {
DXILOpCategory OpCategory; // Category of the operation
string Doc = ""; // Description of the operation
list<DXILOpParameter> Params = []; // Parameter list of the operation
- string OverloadTypes = ""; // Overload types, if applicable
- string Attributes = ""; // Attribute shorthands: rn=does not access
- // memory,ro=only reads from memory,
+ list<LLVMType> OverloadTypes = []; // Overload types, if applicable
+ string Attributes = ""; // Operation Attribute
+ // "NoReadMemory" - does not read memory
+ // "ReadMemory" - reads memory
bit IsDerivative = 0; // Whether this is some kind of derivative
bit IsGradient = 0; // Whether this requires a gradient calculation
bit IsFeedback = 0; // Whether this is a sampler feedback operation
@@ -71,7 +102,7 @@ class DXILOperationDesc {
}
class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory opCategory, string doc,
- string oloadTypes, string attrs, list<DXILOpParameter> params,
+ list<LLVMType> oloadTypes, string attrs, list<DXILOpParameter> params,
list<string> statsGroup = []> : DXILOperationDesc {
let OpName = name;
let OpCode = opCode;
@@ -88,7 +119,7 @@ class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory
class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine(theta) for theta in radians.",
- "half;float;", "rn",
+ [f16Ty,f32Ty], "NoReadMemory",
[
DXILOpParameter<0, "$o", "", "operation result">,
DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -98,7 +129,7 @@ def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine
LLVMIntrinsic<int_sin>;
def UMax : DXILOperation< "UMax", 39, BinaryClass, BinaryUintCategory, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
- "i16;i32;i64;", "rn",
+ [i16Ty,i32Ty,i64Ty], "NoReadMemory",
[
DXILOpParameter<0, "$o", "", "operation result">,
DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -108,7 +139,7 @@ def UMax : DXILOperation< "UMax", 39, BinaryClass, BinaryUintCategory, "unsign
["uints"]>,
LLVMIntrinsic<int_umax>;
-def ThreadId : DXILOperation< "ThreadId", 93, ThreadIdClass, ComputeIDCategory, "reads the thread ID", "i32;", "rn",
+def ThreadId : DXILOperation< "ThreadId", 93, ThreadIdClass, ComputeIDCategory, "reads the thread ID", [i32Ty], "NoReadMemory",
[
DXILOpParameter<0, "i32", "", "thread ID component">,
DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -116,7 +147,7 @@ def ThreadId : DXILOperation< "ThreadId", 93, ThreadIdClass, ComputeIDCategory,
]>,
LLVMIntrinsic<int_dx_thread_id>;
-def GroupId : DXILOperation< "GroupId", 94, GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", "i32;", "rn",
+def GroupId : DXILOperation< "GroupId", 94, GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", [i32Ty], "NoReadMemory",
[
DXILOpParameter<0, "i32", "", "group ID component">,
DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -125,7 +156,7 @@ def GroupId : DXILOperation< "GroupId", 94, GroupIdClass, ComputeIDCategory, "r
LLVMIntrinsic<int_dx_group_id>;
def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95, ThreadIdInGroupClass, ComputeIDCategory,
- "reads the thread ID within the group (SV_GroupThreadID)", "i32;", "rn",
+ "reads the thread ID within the group (SV_GroupThreadID)", [i32Ty], "NoReadMemory",
[
DXILOpParameter<0, "i32", "", "thread ID in group component">,
DXILOpParameter<1, "i32", "opcode", "DXIL opcode">,
@@ -134,7 +165,7 @@ def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95, ThreadIdInGroupClas
LLVMIntrinsic<int_dx_thread_id_in_group>;
def FlattenedThreadIdInGroup : DXILOperation< "FlattenedThreadIdInGroup", 96, FlattenedThreadIdInGroupClass, ComputeIDCategory,
- "provides a flattened index for a given thread within a given group (SV_GroupIndex)", "i32;", "rn",
+ "provides a flattened index for a given thread within a given group (SV_GroupIndex)", [i32Ty], "NoReadMemory",
[
DXILOpParameter<0, "i32", "", "result">,
DXILOpParameter<1, "i32", "opcode", "DXIL opcode">
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index cb9f9c6b03c636..32c5f4ff16f004 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -49,7 +49,7 @@ struct DXILOperationDesc {
StringRef Doc; // the documentation description of this instruction
SmallVector<DXILParameter> Params; // the operands that this instruction takes
- StringRef OverloadTypes; // overload types if applicable
+ SmallVector<ParameterKind> OverloadTypes; // overload types if applicable
StringRef FnAttr; // attribute shorthands: rn=does not access
// memory,ro=only reads from memory
StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
@@ -69,37 +69,31 @@ struct DXILOperationDesc {
int OverloadParamIndex; // parameter index which control the overload.
// When < 0, should be only 1 overload type.
SmallVector<StringRef, 4> counters; // counters for this inst.
- DXILOperationDesc(const Record *R) {
- OpName = R->getValueAsString("OpName");
- OpCode = R->getValueAsInt("OpCode");
- OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
- Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
-
- if (R->getValue("llvm_intrinsic")) {
- auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
- auto DefName = IntrinsicDef->getName();
- assert(DefName.starts_with("int_") && "invalid intrinsic name");
- // Remove the int_ from intrinsic name.
- Intrinsic = DefName.substr(4);
- }
-
- Doc = R->getValueAsString("Doc");
-
- ListInit *ParamList = R->getValueAsListInit("Params");
- OverloadParamIndex = -1;
- for (unsigned I = 0; I < ParamList->size(); ++I) {
- Record *Param = ParamList->getElementAsRecord(I);
- Params.emplace_back(DXILParameter(Param));
- auto &CurParam = Params.back();
- if (CurParam.Kind >= ParameterKind::OVERLOAD)
- OverloadParamIndex = I;
- }
- OverloadTypes = R->getValueAsString("OverloadTypes");
- FnAttr = R->getValueAsString("Attributes");
- }
+ DXILOperationDesc(const Record *);
};
} // end anonymous namespace
+// Convert DXIL type name string to dxil::ParameterKind
+// @param typeNameStr Type name string
+// @return ParameterKind as defined in llvm/Support/DXILABI.h
+static ParameterKind getDXILTypeNameToKind(StringRef typeNameStr) {
+ return StringSwitch<ParameterKind>(typeNameStr)
+ .Case("voidTy", ParameterKind::VOID)
+ .Case("f16Ty", ParameterKind::HALF)
+ .Case("f32Ty", ParameterKind::FLOAT)
+ .Case("f64Ty", ParameterKind::DOUBLE)
+ .Case("i1Ty", ParameterKind::I1)
+ .Case("i8Ty", ParameterKind::I8)
+ .Case("i16Ty", ParameterKind::I16)
+ .Case("i32Ty", ParameterKind::I32)
+ .Case("i64Ty", ParameterKind::I64)
+ .Case("overloadTy", ParameterKind::OVERLOAD)
+ .Case("handleTy", ParameterKind::DXIL_HANDLE)
+ .Case("cbufferRetTy", ParameterKind::CBUFFER_RET)
+ .Case("resourceRetTy", ParameterKind::RESOURCE_RET)
+ .Default(ParameterKind::INVALID);
+}
+
static ParameterKind parameterTypeNameToKind(StringRef Name) {
return StringSwitch<ParameterKind>(Name)
.Case("void", ParameterKind::VOID)
@@ -118,10 +112,44 @@ static ParameterKind parameterTypeNameToKind(StringRef Name) {
.Default(ParameterKind::INVALID);
}
+DXILOperationDesc::DXILOperationDesc(const Record *R) {
+ OpName = R->getValueAsString("OpName");
+ OpCode = R->getValueAsInt("OpCode");
+ OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
+ Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
+
+ if (R->getValue("llvm_intrinsic")) {
+ auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
+ auto DefName = IntrinsicDef->getName();
+ assert(DefName.starts_with("int_") && "invalid intrinsic name");
+ // Remove the int_ from intrinsic name.
+ Intrinsic = DefName.substr(4);
+ }
+
+ Doc = R->getValueAsString("Doc");
+
+ ListInit *ParamList = R->getValueAsListInit("Params");
+ OverloadParamIndex = -1;
+ for (unsigned I = 0; I < ParamList->size(); ++I) {
+ Record *Param = ParamList->getElementAsRecord(I);
+ Params.emplace_back(DXILParameter(Param));
+ auto &CurParam = Params.back();
+ if (CurParam.Kind >= ParameterKind::OVERLOAD)
+ OverloadParamIndex = I;
+ }
+ ListInit *OverloadTypeList = R->getValueAsListInit("OverloadTypes");
+
+ for (unsigned I = 0; I < OverloadTypeList->size(); ++I) {
+ Record *R = OverloadTypeList->getElementAsRecord(I);
+ OverloadTypes.emplace_back(getDXILTypeNameToKind(R->getNameInitAsString()));
+ }
+ FnAttr = R->getValueAsString("Attributes");
+}
+
DXILParameter::DXILParameter(const Record *R) {
Name = R->getValueAsString("Name");
Pos = R->getValueAsInt("Pos");
- Kind = parameterTypeNameToKind(R->getValueAsString("LLVMType"));
+ Kind = parameterTypeNameToKind(R->getValueAsString("Type"));
if (R->getValue("Doc"))
Doc = R->getValueAsString("Doc");
IsConst = R->getValueAsBit("IsConstant");
@@ -268,36 +296,45 @@ static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
static std::string emitDXILOperationFnAttr(StringRef FnAttr) {
return StringSwitch<std::string>(FnAttr)
- .Case("rn", "Attribute::ReadNone")
- .Case("ro", "Attribute::ReadOnly")
+ .Case("NoReadMemory", "Attribute::ReadNone")
+ .Case("ReadMemory", "Attribute::ReadOnly")
.Default("Attribute::None");
}
-static std::string getOverloadKind(StringRef Overload) {
- return StringSwitch<std::string>(Overload)
- .Case("half", "OverloadKind::HALF")
- .Case("float", "OverloadKind::FLOAT")
- .Case("double", "OverloadKind::DOUBLE")
- .Case("i1", "OverloadKind::I1")
- .Case("i16", "OverloadKind::I16")
- .Case("i32", "OverloadKind::I32")
- .Case("i64", "OverloadKind::I64")
- .Case("udt", "OverloadKind::UserDefineType")
- .Case("obj", "OverloadKind::ObjectType")
- .Default("OverloadKind::VOID");
+static std::string overloadKindStr(ParameterKind Overload) {
+ switch (Overload) {
+ case ParameterKind::HALF:
+ return "OverloadKind::HALF";
+ case ParameterKind::FLOAT:
+ return "OverloadKind::FLOAT";
+ case ParameterKind::DOUBLE:
+ return "OverloadKind::DOUBLE";
+ case ParameterKind::I1:
+ return "OverloadKind::I1";
+ case ParameterKind::I8:
+ return "OverloadKind::I8";
+ case ParameterKind::I16:
+ return "OverloadKind::I16";
+ case ParameterKind::I32:
+ return "OverloadKind::I32";
+ case ParameterKind::I64:
+ return "OverloadKind::I64";
+ case ParameterKind::VOID:
+ return "OverloadKind::VOID";
+ default:
+ return "OverloadKind::UNKNOWN";
+ }
}
-static std::string getDXILOperationOverload(StringRef Overloads) {
- SmallVector<StringRef> OverloadStrs;
- Overloads.split(OverloadStrs, ';', /*MaxSplit*/ -1, /*KeepEmpty*/ false);
+static std::string
+getDXILOperationOverloads(SmallVector<ParameterKind> Overloads) {
// Format is: OverloadKind::FLOAT | OverloadKind::HALF
- assert(!OverloadStrs.empty() && "Invalid overloads");
- auto It = OverloadStrs.begin();
+ auto It = Overloads.begin();
std::string Result;
raw_string_ostream OS(Result);
- OS << getOverloadKind(*It);
- for (++It; It != OverloadStrs.end(); ++It) {
- OS << " | " << getOverloadKind(*It);
+ OS << overloadKindStr(*It);
+ for (++It; It != Overloads.end(); ++It) {
+ OS << " | " << overloadKindStr(*It);
}
return OS.str();
}
@@ -367,7 +404,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
OS << " { dxil::OpCode::" << Op.OpName << ", "
<< OpStrings.get(Op.OpName.str()) << ", OpCodeClass::" << Op.OpClass
<< ", " << OpClassStrings.get(getDXILOpClassName(Op.OpClass)) << ", "
- << getDXILOperationOverload(Op.OverloadTypes) << ", "
+ << getDXILOperationOverloads(Op.OverloadTypes) << ", "
<< emitDXILOperationFnAttr(Op.FnAttr) << ", " << Op.OverloadParamIndex
<< ", " << Op.Params.size() << ", "
<< Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
``````````
</details>
https://github.com/llvm/llvm-project/pull/81184
More information about the llvm-commits
mailing list