[llvm] b1c8b9f - [DirectX][NFC] Leverage LLVM and DirectX intrinsic description in DXIL Op records (#83193)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 29 06:21:49 PST 2024
Author: S. Bharadwaj Yadavalli
Date: 2024-02-29T06:21:44-08:00
New Revision: b1c8b9f89cac91db857b9123838ac6b6aeb0ae74
URL: https://github.com/llvm/llvm-project/commit/b1c8b9f89cac91db857b9123838ac6b6aeb0ae74
DIFF: https://github.com/llvm/llvm-project/commit/b1c8b9f89cac91db857b9123838ac6b6aeb0ae74.diff
LOG: [DirectX][NFC] Leverage LLVM and DirectX intrinsic description in DXIL Op records (#83193)
* Leverage TableGen record descriptions of LLVM or DirectX intrinsics
that can be directly mapped in DXIL Ops TableGen description. As a
result, such DXIL Ops can be succinctly described without duplication.
DXILEmitter backend can derive the properties of DXIL Ops accordingly.
* Ensured that corresponding lit tests pass.
Added:
Modified:
llvm/lib/Target/DirectX/DXIL.td
llvm/lib/Target/DirectX/DXILOpBuilder.cpp
llvm/lib/Target/DirectX/DXILOpBuilder.h
llvm/lib/Target/DirectX/DXILOpLowering.cpp
llvm/utils/TableGen/DXILEmitter.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 8a3454c89542ce..67ef7986622092 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -12,139 +12,224 @@
//===----------------------------------------------------------------------===//
include "llvm/IR/Intrinsics.td"
-include "llvm/IR/Attributes.td"
-// Abstract representation of the class a DXIL Operation belongs to.
-class DXILOpClass<string name> {
- string Name = name;
+class DXILOpClass;
+
+// Following is a set of DXIL Operation classes whose names appear to be
+// arbitrary, yet need to be a substring of the function name used during
+// lowering to DXIL Operation calls. These class name strings are specified
+// as the third argument of add_dxil_op in utils/hct/hctdb.py and case converted
+// in utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
+// name has the format "dx.op.<class-name>.<return-type>".
+
+defset list<DXILOpClass> OpClasses = {
+ def acceptHitAndEndSearch : DXILOpClass;
+ def allocateNodeOutputRecords : DXILOpClass;
+ def allocateRayQuery : DXILOpClass;
+ def annotateHandle : DXILOpClass;
+ def annotateNodeHandle : DXILOpClass;
+ def annotateNodeRecordHandle : DXILOpClass;
+ def atomicBinOp : DXILOpClass;
+ def atomicCompareExchange : DXILOpClass;
+ def attributeAtVertex : DXILOpClass;
+ def barrier : DXILOpClass;
+ def barrierByMemoryHandle : DXILOpClass;
+ def barrierByMemoryType : DXILOpClass;
+ def barrierByNodeRecordHandle : DXILOpClass;
+ def binary : DXILOpClass;
+ def binaryWithCarryOrBorrow : DXILOpClass;
+ def binaryWithTwoOuts : DXILOpClass;
+ def bitcastF16toI16 : DXILOpClass;
+ def bitcastF32toI32 : DXILOpClass;
+ def bitcastF64toI64 : DXILOpClass;
+ def bitcastI16toF16 : DXILOpClass;
+ def bitcastI32toF32 : DXILOpClass;
+ def bitcastI64toF64 : DXILOpClass;
+ def bufferLoad : DXILOpClass;
+ def bufferStore : DXILOpClass;
+ def bufferUpdateCounter : DXILOpClass;
+ def calculateLOD : DXILOpClass;
+ def callShader : DXILOpClass;
+ def cbufferLoad : DXILOpClass;
+ def cbufferLoadLegacy : DXILOpClass;
+ def checkAccessFullyMapped : DXILOpClass;
+ def coverage : DXILOpClass;
+ def createHandle : DXILOpClass;
+ def createHandleForLib : DXILOpClass;
+ def createHandleFromBinding : DXILOpClass;
+ def createHandleFromHeap : DXILOpClass;
+ def createNodeInputRecordHandle : DXILOpClass;
+ def createNodeOutputHandle : DXILOpClass;
+ def cutStream : DXILOpClass;
+ def cycleCounterLegacy : DXILOpClass;
+ def discard : DXILOpClass;
+ def dispatchMesh : DXILOpClass;
+ def dispatchRaysDimensions : DXILOpClass;
+ def dispatchRaysIndex : DXILOpClass;
+ def domainLocation : DXILOpClass;
+ def dot2 : DXILOpClass;
+ def dot2AddHalf : DXILOpClass;
+ def dot3 : DXILOpClass;
+ def dot4 : DXILOpClass;
+ def dot4AddPacked : DXILOpClass;
+ def emitIndices : DXILOpClass;
+ def emitStream : DXILOpClass;
+ def emitThenCutStream : DXILOpClass;
+ def evalCentroid : DXILOpClass;
+ def evalSampleIndex : DXILOpClass;
+ def evalSnapped : DXILOpClass;
+ def finishedCrossGroupSharing : DXILOpClass;
+ def flattenedThreadIdInGroup : DXILOpClass;
+ def geometryIndex : DXILOpClass;
+ def getDimensions : DXILOpClass;
+ def getInputRecordCount : DXILOpClass;
+ def getMeshPayload : DXILOpClass;
+ def getNodeRecordPtr : DXILOpClass;
+ def getRemainingRecursionLevels : DXILOpClass;
+ def groupId : DXILOpClass;
+ def gsInstanceID : DXILOpClass;
+ def hitKind : DXILOpClass;
+ def ignoreHit : DXILOpClass;
+ def incrementOutputCount : DXILOpClass;
+ def indexNodeHandle : DXILOpClass;
+ def innerCoverage : DXILOpClass;
+ def instanceID : DXILOpClass;
+ def instanceIndex : DXILOpClass;
+ def isHelperLane : DXILOpClass;
+ def isSpecialFloat : DXILOpClass;
+ def legacyDoubleToFloat : DXILOpClass;
+ def legacyDoubleToSInt32 : DXILOpClass;
+ def legacyDoubleToUInt32 : DXILOpClass;
+ def legacyF16ToF32 : DXILOpClass;
+ def legacyF32ToF16 : DXILOpClass;
+ def loadInput : DXILOpClass;
+ def loadOutputControlPoint : DXILOpClass;
+ def loadPatchConstant : DXILOpClass;
+ def makeDouble : DXILOpClass;
+ def minPrecXRegLoad : DXILOpClass;
+ def minPrecXRegStore : DXILOpClass;
+ def nodeOutputIsValid : DXILOpClass;
+ def objectRayDirection : DXILOpClass;
+ def objectRayOrigin : DXILOpClass;
+ def objectToWorld : DXILOpClass;
+ def outputComplete : DXILOpClass;
+ def outputControlPointID : DXILOpClass;
+ def pack4x8 : DXILOpClass;
+ def primitiveID : DXILOpClass;
+ def primitiveIndex : DXILOpClass;
+ def quadOp : DXILOpClass;
+ def quadReadLaneAt : DXILOpClass;
+ def quadVote : DXILOpClass;
+ def quaternary : DXILOpClass;
+ def rawBufferLoad : DXILOpClass;
+ def rawBufferStore : DXILOpClass;
+ def rayFlags : DXILOpClass;
+ def rayQuery_Abort : DXILOpClass;
+ def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
+ def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
+ def rayQuery_Proceed : DXILOpClass;
+ def rayQuery_StateMatrix : DXILOpClass;
+ def rayQuery_StateScalar : DXILOpClass;
+ def rayQuery_StateVector : DXILOpClass;
+ def rayQuery_TraceRayInline : DXILOpClass;
+ def rayTCurrent : DXILOpClass;
+ def rayTMin : DXILOpClass;
+ def renderTargetGetSampleCount : DXILOpClass;
+ def renderTargetGetSamplePosition : DXILOpClass;
+ def reportHit : DXILOpClass;
+ def sample : DXILOpClass;
+ def sampleBias : DXILOpClass;
+ def sampleCmp : DXILOpClass;
+ def sampleCmpBias : DXILOpClass;
+ def sampleCmpGrad : DXILOpClass;
+ def sampleCmpLevel : DXILOpClass;
+ def sampleCmpLevelZero : DXILOpClass;
+ def sampleGrad : DXILOpClass;
+ def sampleIndex : DXILOpClass;
+ def sampleLevel : DXILOpClass;
+ def setMeshOutputCounts : DXILOpClass;
+ def splitDouble : DXILOpClass;
+ def startInstanceLocation : DXILOpClass;
+ def startVertexLocation : DXILOpClass;
+ def storeOutput : DXILOpClass;
+ def storePatchConstant : DXILOpClass;
+ def storePrimitiveOutput : DXILOpClass;
+ def storeVertexOutput : DXILOpClass;
+ def tempRegLoad : DXILOpClass;
+ def tempRegStore : DXILOpClass;
+ def tertiary : DXILOpClass;
+ def texture2DMSGetSamplePosition : DXILOpClass;
+ def textureGather : DXILOpClass;
+ def textureGatherCmp : DXILOpClass;
+ def textureGatherRaw : DXILOpClass;
+ def textureLoad : DXILOpClass;
+ def textureStore : DXILOpClass;
+ def textureStoreSample : DXILOpClass;
+ def threadId : DXILOpClass;
+ def threadIdInGroup : DXILOpClass;
+ def traceRay : DXILOpClass;
+ def unary : DXILOpClass;
+ def unaryBits : DXILOpClass;
+ def unpack4x8 : DXILOpClass;
+ def viewID : DXILOpClass;
+ def waveActiveAllEqual : DXILOpClass;
+ def waveActiveBallot : DXILOpClass;
+ def waveActiveBit : DXILOpClass;
+ def waveActiveOp : DXILOpClass;
+ def waveAllOp : DXILOpClass;
+ def waveAllTrue : DXILOpClass;
+ def waveAnyTrue : DXILOpClass;
+ def waveGetLaneCount : DXILOpClass;
+ def waveGetLaneIndex : DXILOpClass;
+ def waveIsFirstLane : DXILOpClass;
+ def waveMatch : DXILOpClass;
+ def waveMatrix_Accumulate : DXILOpClass;
+ def waveMatrix_Annotate : DXILOpClass;
+ def waveMatrix_Depth : DXILOpClass;
+ def waveMatrix_Fill : DXILOpClass;
+ def waveMatrix_LoadGroupShared : DXILOpClass;
+ def waveMatrix_LoadRawBuf : DXILOpClass;
+ def waveMatrix_Multiply : DXILOpClass;
+ def waveMatrix_ScalarOp : DXILOpClass;
+ def waveMatrix_StoreGroupShared : DXILOpClass;
+ def waveMatrix_StoreRawBuf : DXILOpClass;
+ def waveMultiPrefixBitCount : DXILOpClass;
+ def waveMultiPrefixOp : DXILOpClass;
+ def wavePrefixOp : DXILOpClass;
+ def waveReadLaneAt : DXILOpClass;
+ def waveReadLaneFirst : DXILOpClass;
+ def worldRayDirection : DXILOpClass;
+ def worldRayOrigin : DXILOpClass;
+ def worldToObject : DXILOpClass;
+ def writeSamplerFeedback : DXILOpClass;
+ def writeSamplerFeedbackBias : DXILOpClass;
+ def writeSamplerFeedbackGrad : DXILOpClass;
+ def writeSamplerFeedbackLevel: DXILOpClass;
}
-// Abstract representation of the category a DXIL Operation belongs to
-class DXILOpCategory<string name> {
- string Name = name;
+// Abstraction of a DXIL Operation and its mapping to an LLVM intrinsic
+class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
+ int OpCode = opCode; // Opcode corresponding to DXIL Operation
+ DXILOpClass OpClass = opClass; // Class of DXIL Operation.
+ Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
+ string Doc = doc; // to a short description of the operation
}
-def UnaryClass : DXILOpClass<"Unary">;
-def BinaryClass : DXILOpClass<"Binary">;
-def FlattenedThreadIdInGroupClass : DXILOpClass<"FlattenedThreadIdInGroup">;
-def ThreadIdInGroupClass : DXILOpClass<"ThreadIdInGroup">;
-def ThreadIdClass : DXILOpClass<"ThreadId">;
-def GroupIdClass : DXILOpClass<"GroupId">;
-
-def BinaryUintCategory : DXILOpCategory<"Binary uint">;
-def UnaryFloatCategory : DXILOpCategory<"Unary float">;
-def ComputeIDCategory : DXILOpCategory<"Compute/Mesh/Amplification shader">;
-
-// Represent as any pointer type with an option to change to a qualified pointer
-// type with address space specified.
-def dxil_handle_ty : LLVMAnyPointerType;
-def dxil_cbuffer_ty : LLVMAnyPointerType;
-def dxil_resource_ty : LLVMAnyPointerType;
-
-// The parameter description for a DXIL operation
-class DXILOpParameter<int pos, LLVMType type, string name, string doc,
- bit isConstant = 0, string enumName = "",
- int maxValue = 0> {
- int Pos = pos; // Position in parameter list
- LLVMType ParamType = type; // Parameter type
- string Name = name; // Short, unique parameter name
- string Doc = doc; // Description of this parameter
- bit IsConstant = isConstant; // Whether this parameter requires a constant value in the IR
- string EnumName = enumName; // Name of the enum type, if applicable
- int MaxValue = maxValue; // Maximum value for this parameter, if applicable
-}
-
-// A representation for a DXIL operation
-class DXILOperationDesc {
- string OpName = ""; // Name of DXIL operation
- int OpCode = 0; // Unique non-negative integer associated with the operation
- DXILOpClass OpClass; // Class of the operation
- DXILOpCategory OpCategory; // Category of the operation
- string Doc = ""; // Description of the operation
- list<DXILOpParameter> Params = []; // Parameter list of the operation
- list<LLVMType> OverloadTypes = []; // Overload types, if applicable
- EnumAttr Attribute; // Operation Attribute. Leverage attributes defined in Attributes.td
- // ReadNone - operation does not access memory.
- // ReadOnly - only reads from memory.
- // "ReadMemory" - reads memory
- bit IsDerivative = 0; // Whether this is some kind of derivative
- bit IsGradient = 0; // Whether this requires a gradient calculation
- bit IsFeedback = 0; // Whether this is a sampler feedback operation
- bit IsWave = 0; // Whether this requires in-wave, cross-lane functionality
- bit NeedsUniformInputs = 0; // Whether this operation requires that all
- // of its inputs are uniform across the wave
- // Group DXIL operation for stats - e.g., to accumulate the number of atomic/float/uint/int/...
- // operations used in the program.
- list<string> StatsGroup = [];
-}
-
-class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory opCategory, string doc,
- list<LLVMType> oloadTypes, EnumAttr attrs, list<DXILOpParameter> params,
- list<string> statsGroup = []> : DXILOperationDesc {
- let OpName = name;
- let OpCode = opCode;
- let Doc = doc;
- let Params = params;
- let OpClass = opClass;
- let OpCategory = opCategory;
- let OverloadTypes = oloadTypes;
- let Attribute = attrs;
- let StatsGroup = statsGroup;
-}
-
-// LLVM intrinsic that DXIL operation maps to.
-class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
-
-def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine(theta) for theta in radians.",
- [llvm_half_ty, llvm_float_ty], ReadNone,
- [
- DXILOpParameter<0, llvm_anyfloat_ty, "", "operation result">,
- DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
- DXILOpParameter<2, llvm_anyfloat_ty, "value", "input value">
- ],
- ["floats"]>,
- LLVMIntrinsic<int_sin>;
-
-def UMax : DXILOperation< "UMax", 39, BinaryClass, BinaryUintCategory, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
- [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty], ReadNone,
- [
- DXILOpParameter<0, llvm_anyint_ty, "", "operation result">,
- DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
- DXILOpParameter<2, llvm_anyint_ty, "a", "input value">,
- DXILOpParameter<3, llvm_anyint_ty, "b", "input value">
- ],
- ["uints"]>,
- LLVMIntrinsic<int_umax>;
-
-def ThreadId : DXILOperation< "ThreadId", 93, ThreadIdClass, ComputeIDCategory, "reads the thread ID", [llvm_i32_ty], ReadNone,
- [
- DXILOpParameter<0, llvm_i32_ty, "", "thread ID component">,
- DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
- DXILOpParameter<2, llvm_i32_ty, "component", "component to read (x,y,z)">
- ]>,
- LLVMIntrinsic<int_dx_thread_id>;
-
-def GroupId : DXILOperation< "GroupId", 94, GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", [llvm_i32_ty], ReadNone,
- [
- DXILOpParameter<0, llvm_i32_ty, "", "group ID component">,
- DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
- DXILOpParameter<2, llvm_i32_ty, "component", "component to read">
- ]>,
- LLVMIntrinsic<int_dx_group_id>;
-
-def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95, ThreadIdInGroupClass, ComputeIDCategory,
- "reads the thread ID within the group (SV_GroupThreadID)", [llvm_i32_ty], ReadNone,
- [
- DXILOpParameter<0, llvm_i32_ty, "", "thread ID in group component">,
- DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
- DXILOpParameter<2, llvm_i32_ty, "component", "component to read (x,y,z)">
- ]>,
- LLVMIntrinsic<int_dx_thread_id_in_group>;
-
-def FlattenedThreadIdInGroup : DXILOperation< "FlattenedThreadIdInGroup", 96, FlattenedThreadIdInGroupClass, ComputeIDCategory,
- "provides a flattened index for a given thread within a given group (SV_GroupIndex)", [llvm_i32_ty], ReadNone,
- [
- DXILOpParameter<0, llvm_i32_ty, "", "result">,
- DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">
- ]>,
- LLVMIntrinsic<int_dx_flattened_thread_id_in_group>;
+// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
+def Sin : DXILOpMapping<13, unary, int_sin,
+ "Returns sine(theta) for theta in radians.">;
+def UMax : DXILOpMapping<39, binary, int_umax,
+ "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
+ "Reads the thread ID">;
+def GroupId : DXILOpMapping<94, groupId, int_dx_group_id,
+ "Reads the group ID (SV_GroupID)">;
+def ThreadIdInGroup : DXILOpMapping<95, threadIdInGroup,
+ int_dx_thread_id_in_group,
+ "Reads the thread ID within the group "
+ "(SV_GroupThreadID)">;
+def FlattenedThreadIdInGroup : DXILOpMapping<96, flattenedThreadIdInGroup,
+ int_dx_flattened_thread_id_in_group,
+ "Provides a flattened index for a "
+ "given thread within a given "
+ "group (SV_GroupIndex)">;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 42180a865b72e3..21a20d45b922d9 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -221,12 +221,26 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
return nullptr;
}
+/// Construct DXIL function type. This is the type of a function with
+/// the following prototype
+/// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param-types>)
+/// <param-types> are constructed from types in Prop.
+/// \param Prop Structure containing DXIL Operation properties based on
+/// its specification in DXIL.td.
+/// \param OverloadTy Return type to be used to construct DXIL function type.
static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
Type *OverloadTy) {
SmallVector<Type *> ArgTys;
auto ParamKinds = getOpCodeParameterKind(*Prop);
+ // Add OverloadTy as return type of the function
+ ArgTys.emplace_back(OverloadTy);
+
+ // Add DXIL Opcode value type viz., Int32 as first argument
+ ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
+
+ // Add DXIL Operation parameter types as specified in DXIL properties
for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
ParameterKind Kind = ParamKinds[I];
ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
@@ -267,13 +281,13 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
return B.CreateCall(Fn, FullArgs);
}
-Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
- bool NoOpCodeParam) {
+Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+ // If DXIL Op has no overload parameter, just return the
+ // precise return type specified.
if (Prop->OverloadParamIndex < 0) {
auto &Ctx = FT->getContext();
- // When only has 1 overload type, just return it.
switch (Prop->OverloadTys) {
case OverloadKind::VOID:
return Type::getVoidTy(Ctx);
@@ -302,9 +316,8 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
// Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
Type *OverloadType = FT->getReturnType();
if (Prop->OverloadParamIndex != 0) {
- // Skip Return Type and Type for DXIL opcode.
- const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
- OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
+ // Skip Return Type.
+ OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);
}
auto ParamKinds = getOpCodeParameterKind(*Prop);
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 940ed538c7ce15..1c15f109184adf 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -31,8 +31,7 @@ class DXILOpBuilder {
DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
llvm::iterator_range<Use *> Args);
- Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
- bool NoOpCodeParam);
+ Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
static const char *getOpCodeName(dxil::OpCode DXILOp);
private:
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f6e2297e9af41f..6b649b76beecdf 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -33,8 +33,7 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
IRBuilder<> B(M.getContext());
Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
DXILOpBuilder DXILB(M, B);
- Type *OverloadTy =
- DXILB.getOverloadTy(DXILOp, F.getFunctionType(), /*NoOpCodeParam*/ true);
+ Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
for (User *U : make_early_inc_range(F.users())) {
CallInst *CI = dyn_cast<CallInst>(U);
if (!CI)
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index d47df597d53a35..fc958f5328736c 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -11,11 +11,14 @@
//
//===----------------------------------------------------------------------===//
+#include "CodeGenTarget.h"
#include "SequenceToOffsetTable.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/StringSwitch.h"
+#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -30,28 +33,15 @@ struct DXILShaderModel {
int Minor = 0;
};
-struct DXILParameter {
- int Pos; // position in parameter list
- ParameterKind Kind;
- StringRef Name; // short, unique name
- StringRef Doc; // the documentation description of this parameter
- bool IsConst; // whether this argument requires a constant value in the IR
- StringRef EnumName; // the name of the enum type if applicable
- int MaxValue; // the maximum value for this parameter if applicable
- DXILParameter(const Record *R);
-};
-
struct DXILOperationDesc {
- StringRef OpName; // name of DXIL operation
+ std::string OpName; // name of DXIL operation
int OpCode; // ID of DXIL operation
StringRef OpClass; // name of the opcode class
- StringRef Category; // classification for this instruction
StringRef Doc; // the documentation description of this instruction
-
- SmallVector<DXILParameter> Params; // the operands that this instruction takes
- SmallVector<ParameterKind> OverloadTypes; // overload types if applicable
- StringRef Attr; // operation attribute; reference to string representation
- // of llvm::Attribute::AttrKind
+ SmallVector<MVT::SimpleValueType> OpTypes; // Vector of operand types -
+ // return type is at index 0
+ SmallVector<std::string>
+ OpAttributes; // operation attribute represented as strings
StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
// means no map exists
bool IsDeriv = false; // whether this is some kind of derivative
@@ -74,81 +64,99 @@ struct DXILOperationDesc {
};
} // end anonymous namespace
-/*!
- Convert DXIL type name string to dxil::ParameterKind
-
- @param typeNameStr Type name string
- @return ParameterKind As defined in llvm/Support/DXILABI.h
-*/
-static ParameterKind lookupParameterKind(StringRef typeNameStr) {
- auto paramKind = StringSwitch<ParameterKind>(typeNameStr)
- .Case("llvm_void_ty", ParameterKind::VOID)
- .Case("llvm_half_ty", ParameterKind::HALF)
- .Case("llvm_float_ty", ParameterKind::FLOAT)
- .Case("llvm_double_ty", ParameterKind::DOUBLE)
- .Case("llvm_i1_ty", ParameterKind::I1)
- .Case("llvm_i8_ty", ParameterKind::I8)
- .Case("llvm_i16_ty", ParameterKind::I16)
- .Case("llvm_i32_ty", ParameterKind::I32)
- .Case("llvm_i64_ty", ParameterKind::I64)
- .Case("llvm_anyfloat_ty", ParameterKind::OVERLOAD)
- .Case("llvm_anyint_ty", ParameterKind::OVERLOAD)
- .Case("dxil_handle_ty", ParameterKind::DXIL_HANDLE)
- .Case("dxil_cbuffer_ty", ParameterKind::CBUFFER_RET)
- .Case("dxil_resource_ty", ParameterKind::RESOURCE_RET)
- .Default(ParameterKind::INVALID);
- assert(paramKind != ParameterKind::INVALID &&
- "Unsupported DXIL Type specified");
- return paramKind;
+/// Convert a Simple Value Type to dxil::ParameterKind
+///
+/// \param VT Simple Value Type
+/// \return ParameterKind As defined in llvm/Support/DXILABI.h
+
+static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
+ switch (VT) {
+ case MVT::isVoid:
+ return ParameterKind::VOID;
+ case MVT::f16:
+ return ParameterKind::HALF;
+ case MVT::f32:
+ return ParameterKind::FLOAT;
+ case MVT::f64:
+ return ParameterKind::DOUBLE;
+ case MVT::i1:
+ return ParameterKind::I1;
+ case MVT::i8:
+ return ParameterKind::I8;
+ case MVT::i16:
+ return ParameterKind::I16;
+ case MVT::i32:
+ return ParameterKind::I32;
+ case MVT::fAny:
+ case MVT::iAny:
+ return ParameterKind::OVERLOAD;
+ default:
+ llvm_unreachable("Support for specified DXIL Type not yet implemented");
+ }
}
+/// Construct an object using the DXIL Operation records specified
+/// in DXIL.td. This serves as the single source of reference of
+/// the information extracted from the specified Record R, for
+/// C++ code generated by this TableGen backend.
+/// \param R Object representing TableGen record of a DXIL Operation
DXILOperationDesc::DXILOperationDesc(const Record *R) {
- OpName = R->getValueAsString("OpName");
+ OpName = R->getNameInitAsString();
OpCode = R->getValueAsInt("OpCode");
- OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
- Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
- if (R->getValue("llvm_intrinsic")) {
- auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
+ Doc = R->getValueAsString("Doc");
+
+ if (R->getValue("LLVMIntrinsic")) {
+ auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
auto DefName = IntrinsicDef->getName();
assert(DefName.starts_with("int_") && "invalid intrinsic name");
// Remove the int_ from intrinsic name.
Intrinsic = DefName.substr(4);
+ // TODO: It is expected that return type and parameter types of
+ // DXIL Operation are the same as that of the intrinsic. Deviations
+ // are expected to be encoded in TableGen record specification and
+ // handled accordingly here. Support to be added later, as needed.
+ // Get parameter type list of the intrinsic. Types attribute contains
+ // the list of types as [returnType, param1Type, param2Type, ...]
+
+ OverloadParamIndex = -1;
+ auto TypeRecs = IntrinsicDef->getValueAsListOfDefs("Types");
+ unsigned TypeRecsSize = TypeRecs.size();
+ // Populate return type and parameter type names
+ for (unsigned i = 0; i < TypeRecsSize; i++) {
+ auto TR = TypeRecs[i];
+ OpTypes.emplace_back(getValueType(TR->getValueAsDef("VT")));
+ // Get the overload parameter index.
+ // TODO : Seems hacky. Is it possible that more than one parameter can
+ // be of overload kind??
+ // TODO: Check for any additional constraints specified for DXIL operation
+ // restricting return type.
+ if (i > 0) {
+ auto &CurParam = OpTypes.back();
+ if (getParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
+ OverloadParamIndex = i;
+ }
+ }
+ }
+ // Get the operation class
+ OpClass = R->getValueAsDef("OpClass")->getName();
+
+ // NOTE: For now, assume that attributes of DXIL Operation are the same as
+ // that of the intrinsic. Deviations are expected to be encoded in TableGen
+ // record specification and handled accordingly here. Support to be added
+ // later.
+ auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
+ auto IntrPropListSize = IntrPropList->size();
+ for (unsigned i = 0; i < IntrPropListSize; i++) {
+ OpAttributes.emplace_back(IntrPropList->getElement(i)->getAsString());
+ }
}
-
- Doc = R->getValueAsString("Doc");
-
- ListInit *ParamList = R->getValueAsListInit("Params");
- OverloadParamIndex = -1;
- for (unsigned I = 0; I < ParamList->size(); ++I) {
- Record *Param = ParamList->getElementAsRecord(I);
- Params.emplace_back(DXILParameter(Param));
- auto &CurParam = Params.back();
- if (CurParam.Kind >= ParameterKind::OVERLOAD)
- OverloadParamIndex = I;
- }
- ListInit *OverloadTypeList = R->getValueAsListInit("OverloadTypes");
-
- for (unsigned I = 0; I < OverloadTypeList->size(); ++I) {
- Record *R = OverloadTypeList->getElementAsRecord(I);
- OverloadTypes.emplace_back(lookupParameterKind(R->getNameInitAsString()));
- }
- Attr = StringRef(R->getValue("Attribute")->getNameInitAsString());
}
-DXILParameter::DXILParameter(const Record *R) {
- Name = R->getValueAsString("Name");
- Pos = R->getValueAsInt("Pos");
- Kind =
- lookupParameterKind(R->getValue("ParamType")->getValue()->getAsString());
- if (R->getValue("Doc"))
- Doc = R->getValueAsString("Doc");
- IsConst = R->getValueAsBit("IsConstant");
- EnumName = R->getValueAsString("EnumName");
- MaxValue = R->getValueAsInt("MaxValue");
-}
-
-static std::string parameterKindToString(ParameterKind Kind) {
+/// Return a string representation of ParameterKind enum
+/// \param Kind Parameter Kind enum value
+/// \return std::string string representation of input Kind
+static std::string getParameterKindStr(ParameterKind Kind) {
switch (Kind) {
case ParameterKind::INVALID:
return "INVALID";
@@ -182,92 +190,77 @@ static std::string parameterKindToString(ParameterKind Kind) {
llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
}
-static void emitDXILOpEnum(DXILOperationDesc &Op, raw_ostream &OS) {
- // Name = ID, // Doc
- OS << Op.OpName << " = " << Op.OpCode << ", // " << Op.Doc << "\n";
-}
+/// Return a string representation of OverloadKind enum that maps to
+/// input Simple Value Type enum
+/// \param VT Simple Value Type enum
+/// \return std::string string representation of OverloadKind
-static std::string buildCategoryStr(StringSet<> &Cetegorys) {
- std::string Str;
- raw_string_ostream OS(Str);
- for (auto &It : Cetegorys) {
- OS << " " << It.getKey();
+static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
+ switch (VT) {
+ case MVT::isVoid:
+ return "OverloadKind::VOID";
+ case MVT::f16:
+ return "OverloadKind::HALF";
+ case MVT::f32:
+ return "OverloadKind::FLOAT";
+ case MVT::f64:
+ return "OverloadKind::DOUBLE";
+ case MVT::i1:
+ return "OverloadKind::I1";
+ case MVT::i8:
+ return "OverloadKind::I8";
+ case MVT::i16:
+ return "OverloadKind::I16";
+ case MVT::i32:
+ return "OverloadKind::I32";
+ case MVT::i64:
+ return "OverloadKind::I64";
+ case MVT::iAny:
+ return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
+ case MVT::fAny:
+ return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
+ default:
+ llvm_unreachable(
+ "Support for specified parameter OverloadKind not yet implemented");
}
- return OS.str();
}
-// Emit enum declaration for DXIL.
+/// Emit Enums of DXIL Ops
+/// \param Ops A vector of DXIL Ops
+/// \param OS Output stream
static void emitDXILEnums(std::vector<DXILOperationDesc> &Ops,
raw_ostream &OS) {
- // Sort by Category + OpName.
+ // Sort by OpCode
llvm::sort(Ops, [](DXILOperationDesc &A, DXILOperationDesc &B) {
- // Group by Category first.
- if (A.Category == B.Category)
- // Inside same Category, order by OpName.
- return A.OpName < B.OpName;
- else
- return A.Category < B.Category;
+ return A.OpCode < B.OpCode;
});
OS << "// Enumeration for operations specified by DXIL\n";
OS << "enum class OpCode : unsigned {\n";
- StringMap<StringSet<>> ClassMap;
- StringRef PrevCategory = "";
for (auto &Op : Ops) {
- StringRef Category = Op.Category;
- if (Category != PrevCategory) {
- OS << "\n// " << Category << "\n";
- PrevCategory = Category;
- }
- emitDXILOpEnum(Op, OS);
- auto It = ClassMap.find(Op.OpClass);
- if (It != ClassMap.end()) {
- It->second.insert(Op.Category);
- } else {
- ClassMap[Op.OpClass].insert(Op.Category);
- }
+ // Name = ID, // Doc
+ OS << Op.OpName << " = " << Op.OpCode << ", // " << Op.Doc << "\n";
}
OS << "\n};\n\n";
- std::vector<std::pair<std::string, std::string>> ClassVec;
- for (auto &It : ClassMap) {
- ClassVec.emplace_back(
- std::pair(It.getKey().str(), buildCategoryStr(It.second)));
- }
- // Sort by Category + ClassName.
- llvm::sort(ClassVec, [](std::pair<std::string, std::string> &A,
- std::pair<std::string, std::string> &B) {
- StringRef ClassA = A.first;
- StringRef CategoryA = A.second;
- StringRef ClassB = B.first;
- StringRef CategoryB = B.second;
- // Group by Category first.
- if (CategoryA == CategoryB)
- // Inside same Category, order by ClassName.
- return ClassA < ClassB;
- else
- return CategoryA < CategoryB;
- });
-
OS << "// Groups for DXIL operations with equivalent function templates\n";
OS << "enum class OpCodeClass : unsigned {\n";
- PrevCategory = "";
- for (auto &It : ClassVec) {
-
- StringRef Category = It.second;
- if (Category != PrevCategory) {
- OS << "\n// " << Category << "\n";
- PrevCategory = Category;
- }
- StringRef Name = It.first;
- OS << Name << ",\n";
+ // Build an OpClass set to print
+ SmallSet<StringRef, 2> OpClassSet;
+ for (auto &Op : Ops) {
+ OpClassSet.insert(Op.OpClass);
+ }
+ for (auto &C : OpClassSet) {
+ OS << C << ",\n";
}
OS << "\n};\n\n";
}
-// Emit map from llvm intrinsic to DXIL operation.
+/// Emit map of DXIL operation to LLVM or DirectX intrinsic
+/// \param Ops A vector of DXIL Ops
+/// \param OS Output stream
static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
raw_ostream &OS) {
OS << "\n";
@@ -285,75 +278,27 @@ static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
OS << "\n";
}
-/*!
- Convert operation attribute string to Attribute enum
-
- @param Attr string reference
- @return std::string Attribute enum string
- */
-static std::string emitDXILOperationAttr(StringRef Attr) {
- return StringSwitch<std::string>(Attr)
- .Case("ReadNone", "Attribute::ReadNone")
- .Case("ReadOnly", "Attribute::ReadOnly")
- .Default("Attribute::None");
-}
-
-static std::string overloadKindStr(ParameterKind Overload) {
- switch (Overload) {
- case ParameterKind::HALF:
- return "OverloadKind::HALF";
- case ParameterKind::FLOAT:
- return "OverloadKind::FLOAT";
- case ParameterKind::DOUBLE:
- return "OverloadKind::DOUBLE";
- case ParameterKind::I1:
- return "OverloadKind::I1";
- case ParameterKind::I8:
- return "OverloadKind::I8";
- case ParameterKind::I16:
- return "OverloadKind::I16";
- case ParameterKind::I32:
- return "OverloadKind::I32";
- case ParameterKind::I64:
- return "OverloadKind::I64";
- case ParameterKind::VOID:
- return "OverloadKind::VOID";
- default:
- return "OverloadKind::UNKNOWN";
- }
-}
-
-static std::string
-getDXILOperationOverloads(SmallVector<ParameterKind> Overloads) {
- // Format is: OverloadKind::FLOAT | OverloadKind::HALF
- auto It = Overloads.begin();
- std::string Result;
- raw_string_ostream OS(Result);
- OS << overloadKindStr(*It);
- for (++It; It != Overloads.end(); ++It) {
- OS << " | " << overloadKindStr(*It);
+/// Convert operation attribute string to Attribute enum
+///
+/// \param Attr string reference
+/// \return std::string Attribute enum string
+
+static std::string emitDXILOperationAttr(SmallVector<std::string> Attrs) {
+ for (auto Attr : Attrs) {
+ // TODO: For now just recognize IntrNoMem and IntrReadMem as valid and
+ // ignore others.
+ if (Attr == "IntrNoMem") {
+ return "Attribute::ReadNone";
+ } else if (Attr == "IntrReadMem") {
+ return "Attribute::ReadOnly";
+ }
}
- return OS.str();
-}
-
-static std::string lowerFirstLetter(StringRef Name) {
- if (Name.empty())
- return "";
-
- std::string LowerName = Name.str();
- LowerName[0] = llvm::toLower(Name[0]);
- return LowerName;
-}
-
-static std::string getDXILOpClassName(StringRef OpClass) {
- // Lower first letter expect for special case.
- return StringSwitch<std::string>(OpClass)
- .Case("CBufferLoad", "cbufferLoad")
- .Case("CBufferLoadLegacy", "cbufferLoadLegacy")
- .Case("GSInstanceID", "gsInstanceID")
- .Default(lowerFirstLetter(OpClass));
+ return "Attribute::None";
}
+/// Emit DXIL operation table
+/// \param A vector of DXIL Ops
+/// \param Output stream
static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
raw_ostream &OS) {
// Sort by OpCode.
@@ -369,15 +314,16 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
StringMap<SmallVector<ParameterKind>> ParameterMap;
StringSet<> ClassSet;
for (auto &Op : Ops) {
- OpStrings.add(Op.OpName.str());
+ OpStrings.add(Op.OpName);
if (ClassSet.contains(Op.OpClass))
continue;
ClassSet.insert(Op.OpClass);
- OpClassStrings.add(getDXILOpClassName(Op.OpClass));
+ OpClassStrings.add(Op.OpClass.data());
SmallVector<ParameterKind> ParamKindVec;
- for (auto &Param : Op.Params) {
- ParamKindVec.emplace_back(Param.Kind);
+ // ParamKindVec is a vector of parameters. Skip return type at index 0
+ for (unsigned i = 1; i < Op.OpTypes.size(); i++) {
+ ParamKindVec.emplace_back(getParameterKind(Op.OpTypes[i]));
}
ParameterMap[Op.OpClass] = ParamKindVec;
Parameters.add(ParamKindVec);
@@ -389,7 +335,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
Parameters.layout();
// Emit the DXIL operation table.
- //{dxil::OpCode::Sin, OpCodeNameIndex, OpCodeClass::Unary,
+ //{dxil::OpCode::Sin, OpCodeNameIndex, OpCodeClass::unary,
// OpCodeClassNameIndex,
// OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
// 3, ParameterTableOffset},
@@ -398,12 +344,12 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
OS << " static const OpCodeProperty OpCodeProps[] = {\n";
for (auto &Op : Ops) {
- OS << " { dxil::OpCode::" << Op.OpName << ", "
- << OpStrings.get(Op.OpName.str()) << ", OpCodeClass::" << Op.OpClass
- << ", " << OpClassStrings.get(getDXILOpClassName(Op.OpClass)) << ", "
- << getDXILOperationOverloads(Op.OverloadTypes) << ", "
- << emitDXILOperationAttr(Op.Attr) << ", " << Op.OverloadParamIndex
- << ", " << Op.Params.size() << ", "
+ OS << " { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
+ << ", OpCodeClass::" << Op.OpClass << ", "
+ << OpClassStrings.get(Op.OpClass.data()) << ", "
+ << getOverloadKindStr(Op.OpTypes[0]) << ", "
+ << emitDXILOperationAttr(Op.OpAttributes) << ", "
+ << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
<< Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
}
OS << " };\n";
@@ -418,7 +364,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
"OpCodeProperty &B) {\n";
OS << " return A.OpCode < B.OpCode;\n";
OS << " });\n";
- OS << " assert(Prop && \"fail to find OpCodeProperty\");\n";
+ OS << " assert(Prop && \"failed to find OpCodeProperty\");\n";
OS << " return Prop;\n";
OS << "}\n\n";
@@ -450,7 +396,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
Parameters.emit(
OS,
[](raw_ostream &ParamOS, ParameterKind Kind) {
- ParamOS << "ParameterKind::" << parameterKindToString(Kind);
+ ParamOS << "ParameterKind::" << getParameterKindStr(Kind);
},
"ParameterKind::INVALID");
OS << " };\n\n";
@@ -459,30 +405,28 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
OS << "}\n ";
}
+/// Entry function call that invokes the functionality of this TableGen backend
+/// \param Records TableGen records of DXIL Operations defined in DXIL.td
+/// \param OS output stream
static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
- std::vector<Record *> Ops = Records.getAllDerivedDefinitions("DXILOperation");
OS << "// Generated code, do not edit.\n";
OS << "\n";
-
+ // Get all DXIL Ops to intrinsic mapping records
+ std::vector<Record *> OpIntrMaps =
+ Records.getAllDerivedDefinitions("DXILOpMapping");
std::vector<DXILOperationDesc> DXILOps;
- DXILOps.reserve(Ops.size());
- for (auto *Record : Ops) {
+ for (auto *Record : OpIntrMaps) {
DXILOps.emplace_back(DXILOperationDesc(Record));
}
-
OS << "#ifdef DXIL_OP_ENUM\n";
emitDXILEnums(DXILOps, OS);
OS << "#endif\n\n";
-
OS << "#ifdef DXIL_OP_INTRINSIC_MAP\n";
emitDXILIntrinsicMap(DXILOps, OS);
OS << "#endif\n\n";
-
OS << "#ifdef DXIL_OP_OPERATION_TABLE\n";
emitDXILOperationTable(DXILOps, OS);
OS << "#endif\n\n";
-
- OS << "\n";
}
static TableGen::Emitter::Opt X("gen-dxil-operation", EmitDXILOperation,
More information about the llvm-commits
mailing list