[llvm] [DXIL] Model DXIL Class and Shader Model association of DXIL Ops in DXIL.td (PR #87803)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 5 09:44:51 PDT 2024


https://github.com/bharadwajy created https://github.com/llvm/llvm-project/pull/87803

* Add specification of DXIL Op class and Shader Model.

  Each DXIL Op belongs to a class. A DXIL class represents DXIL Ops
with the same function prototype (or signature). This changeset adds a
TableGen specification of the DXIL Op class. This allows the prototype
information of the DXIL class that a DXIL Op belongs to be used instead
of inheriting the return and parameter type information from the LLVM
intrinsic. Using the DXIL class avoids the currently implemented
definitions of new narrow DXIL types such as `llvm_halforfloat_ty`
(hence deleted), and is more accurate and precise.

* Add specification of Shader Model version.

  Each DXIL Op has a set of valid overloads. The validity of overload types
depends on the minimum shader model version. Expressing such constraints in
DXIL Op records is needed to ensure valid code generation by the DXIL
lowering pass.

  This changeset implements a specification mechanism that associates DXIL Ops
with the classes they belong to and associates the minimum shader model
version with valid overload types.

* Restructure test of lowering `llvm.sin.*`

  This test pattern is expected to make it easier to reuse the same test
sources to test lowering under various combinations of options.

>From 69ef603e7c6e8adf16d53390797261d0f6ac46c9 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 29 Mar 2024 12:11:35 -0400
Subject: [PATCH 1/4] [DXIL] Implement DXIL Ops specification using OpClass
 properties Each DXIL OpClass represents DXIL Ops with the same function
 prototype (signature). Represent this property in a TableGen class and add an
 overload types field with the DXILOpMapping to denote valid overload types of
 a DXIL Op record being defined.

---
 llvm/lib/Target/DirectX/DXIL.td     | 632 ++++++++++++++--------------
 llvm/utils/TableGen/DXILEmitter.cpp | 108 ++---
 2 files changed, 382 insertions(+), 358 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index cd388ed3e3191b..082e7596aded35 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,331 +13,349 @@
 
 include "llvm/IR/Intrinsics.td"
 
-class DXILOpClass;
+// Abstraction of a DXIL Operation class.
+// It encapsulates an associated function signature, viz.,
+// returnty(param1, param2, ...), represented as a list of LLVMTypes.
+// DXIL Ops that belong to a DXILOpClass share the signature of that
+// DXILOpClass.
 
-// Following is a set of DXIL Operation classes whose names appear to be
-// arbitrary, yet need to be a substring of the function name used during
-// lowering to DXIL Operation calls. These class name strings are specified
-// as the third argument of add_dixil_op in utils/hct/hctdb.py and case converted
-// in utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
-// name has the format "dx.op.<class-name>.<return-type>".
+class DXILOpClass<list<LLVMType> OpSig> {
+  list<LLVMType> OpSignature = OpSig;
+}
+
+// Concrete definitions of DXIL Op Classes.
+// Note that these class name strings are specified as the third argument of
+// add_dxil_op in utils/hct/hctdb.py and case-converted in
+// utils/hct/hctdb_instrhelp.py of the DirectXShaderCompiler repo. In most
+// cases, the function name has the format "dx.op.<class-name>.<return-type>".
 
 defset list<DXILOpClass> OpClasses = {
-  def acceptHitAndEndSearch : DXILOpClass;
-  def allocateNodeOutputRecords : DXILOpClass;
-  def allocateRayQuery : DXILOpClass;
-  def annotateHandle : DXILOpClass;
-  def annotateNodeHandle : DXILOpClass;
-  def annotateNodeRecordHandle : DXILOpClass;
-  def atomicBinOp : DXILOpClass;
-  def atomicCompareExchange : DXILOpClass;
-  def attributeAtVertex : DXILOpClass;
-  def barrier : DXILOpClass;
-  def barrierByMemoryHandle : DXILOpClass;
-  def barrierByMemoryType : DXILOpClass;
-  def barrierByNodeRecordHandle : DXILOpClass;
-  def binary : DXILOpClass;
-  def binaryWithCarryOrBorrow : DXILOpClass;
-  def binaryWithTwoOuts : DXILOpClass;
-  def bitcastF16toI16 : DXILOpClass;
-  def bitcastF32toI32 : DXILOpClass;
-  def bitcastF64toI64 : DXILOpClass;
-  def bitcastI16toF16 : DXILOpClass;
-  def bitcastI32toF32 : DXILOpClass;
-  def bitcastI64toF64 : DXILOpClass;
-  def bufferLoad : DXILOpClass;
-  def bufferStore : DXILOpClass;
-  def bufferUpdateCounter : DXILOpClass;
-  def calculateLOD : DXILOpClass;
-  def callShader : DXILOpClass;
-  def cbufferLoad : DXILOpClass;
-  def cbufferLoadLegacy : DXILOpClass;
-  def checkAccessFullyMapped : DXILOpClass;
-  def coverage : DXILOpClass;
-  def createHandle : DXILOpClass;
-  def createHandleForLib : DXILOpClass;
-  def createHandleFromBinding : DXILOpClass;
-  def createHandleFromHeap : DXILOpClass;
-  def createNodeInputRecordHandle : DXILOpClass;
-  def createNodeOutputHandle : DXILOpClass;
-  def cutStream : DXILOpClass;
-  def cycleCounterLegacy : DXILOpClass;
-  def discard : DXILOpClass;
-  def dispatchMesh : DXILOpClass;
-  def dispatchRaysDimensions : DXILOpClass;
-  def dispatchRaysIndex : DXILOpClass;
-  def domainLocation : DXILOpClass;
-  def dot2 : DXILOpClass;
-  def dot2AddHalf : DXILOpClass;
-  def dot3 : DXILOpClass;
-  def dot4 : DXILOpClass;
-  def dot4AddPacked : DXILOpClass;
-  def emitIndices : DXILOpClass;
-  def emitStream : DXILOpClass;
-  def emitThenCutStream : DXILOpClass;
-  def evalCentroid : DXILOpClass;
-  def evalSampleIndex : DXILOpClass;
-  def evalSnapped : DXILOpClass;
-  def finishedCrossGroupSharing : DXILOpClass;
-  def flattenedThreadIdInGroup : DXILOpClass;
-  def geometryIndex : DXILOpClass;
-  def getDimensions : DXILOpClass;
-  def getInputRecordCount : DXILOpClass;
-  def getMeshPayload : DXILOpClass;
-  def getNodeRecordPtr : DXILOpClass;
-  def getRemainingRecursionLevels : DXILOpClass;
-  def groupId : DXILOpClass;
-  def gsInstanceID : DXILOpClass;
-  def hitKind : DXILOpClass;
-  def ignoreHit : DXILOpClass;
-  def incrementOutputCount : DXILOpClass;
-  def indexNodeHandle : DXILOpClass;
-  def innerCoverage : DXILOpClass;
-  def instanceID : DXILOpClass;
-  def instanceIndex : DXILOpClass;
-  def isHelperLane : DXILOpClass;
-  def isSpecialFloat : DXILOpClass;
-  def legacyDoubleToFloat : DXILOpClass;
-  def legacyDoubleToSInt32 : DXILOpClass;
-  def legacyDoubleToUInt32 : DXILOpClass;
-  def legacyF16ToF32 : DXILOpClass;
-  def legacyF32ToF16 : DXILOpClass;
-  def loadInput : DXILOpClass;
-  def loadOutputControlPoint : DXILOpClass;
-  def loadPatchConstant : DXILOpClass;
-  def makeDouble : DXILOpClass;
-  def minPrecXRegLoad : DXILOpClass;
-  def minPrecXRegStore : DXILOpClass;
-  def nodeOutputIsValid : DXILOpClass;
-  def objectRayDirection : DXILOpClass;
-  def objectRayOrigin : DXILOpClass;
-  def objectToWorld : DXILOpClass;
-  def outputComplete : DXILOpClass;
-  def outputControlPointID : DXILOpClass;
-  def pack4x8 : DXILOpClass;
-  def primitiveID : DXILOpClass;
-  def primitiveIndex : DXILOpClass;
-  def quadOp : DXILOpClass;
-  def quadReadLaneAt : DXILOpClass;
-  def quadVote : DXILOpClass;
-  def quaternary : DXILOpClass;
-  def rawBufferLoad : DXILOpClass;
-  def rawBufferStore : DXILOpClass;
-  def rayFlags : DXILOpClass;
-  def rayQuery_Abort : DXILOpClass;
-  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
-  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
-  def rayQuery_Proceed : DXILOpClass;
-  def rayQuery_StateMatrix : DXILOpClass;
-  def rayQuery_StateScalar : DXILOpClass;
-  def rayQuery_StateVector : DXILOpClass;
-  def rayQuery_TraceRayInline : DXILOpClass;
-  def rayTCurrent : DXILOpClass;
-  def rayTMin : DXILOpClass;
-  def renderTargetGetSampleCount : DXILOpClass;
-  def renderTargetGetSamplePosition : DXILOpClass;
-  def reportHit : DXILOpClass;
-  def sample : DXILOpClass;
-  def sampleBias : DXILOpClass;
-  def sampleCmp : DXILOpClass;
-  def sampleCmpBias : DXILOpClass;
-  def sampleCmpGrad : DXILOpClass;
-  def sampleCmpLevel : DXILOpClass;
-  def sampleCmpLevelZero : DXILOpClass;
-  def sampleGrad : DXILOpClass;
-  def sampleIndex : DXILOpClass;
-  def sampleLevel : DXILOpClass;
-  def setMeshOutputCounts : DXILOpClass;
-  def splitDouble : DXILOpClass;
-  def startInstanceLocation : DXILOpClass;
-  def startVertexLocation : DXILOpClass;
-  def storeOutput : DXILOpClass;
-  def storePatchConstant : DXILOpClass;
-  def storePrimitiveOutput : DXILOpClass;
-  def storeVertexOutput : DXILOpClass;
-  def tempRegLoad : DXILOpClass;
-  def tempRegStore : DXILOpClass;
-  def tertiary : DXILOpClass;
-  def texture2DMSGetSamplePosition : DXILOpClass;
-  def textureGather : DXILOpClass;
-  def textureGatherCmp : DXILOpClass;
-  def textureGatherRaw : DXILOpClass;
-  def textureLoad : DXILOpClass;
-  def textureStore : DXILOpClass;
-  def textureStoreSample : DXILOpClass;
-  def threadId : DXILOpClass;
-  def threadIdInGroup : DXILOpClass;
-  def traceRay : DXILOpClass;
-  def unary : DXILOpClass;
-  def unaryBits : DXILOpClass;
-  def unpack4x8 : DXILOpClass;
-  def viewID : DXILOpClass;
-  def waveActiveAllEqual : DXILOpClass;
-  def waveActiveBallot : DXILOpClass;
-  def waveActiveBit : DXILOpClass;
-  def waveActiveOp : DXILOpClass;
-  def waveAllOp : DXILOpClass;
-  def waveAllTrue : DXILOpClass;
-  def waveAnyTrue : DXILOpClass;
-  def waveGetLaneCount : DXILOpClass;
-  def waveGetLaneIndex : DXILOpClass;
-  def waveIsFirstLane : DXILOpClass;
-  def waveMatch : DXILOpClass;
-  def waveMatrix_Accumulate : DXILOpClass;
-  def waveMatrix_Annotate : DXILOpClass;
-  def waveMatrix_Depth : DXILOpClass;
-  def waveMatrix_Fill : DXILOpClass;
-  def waveMatrix_LoadGroupShared : DXILOpClass;
-  def waveMatrix_LoadRawBuf : DXILOpClass;
-  def waveMatrix_Multiply : DXILOpClass;
-  def waveMatrix_ScalarOp : DXILOpClass;
-  def waveMatrix_StoreGroupShared : DXILOpClass;
-  def waveMatrix_StoreRawBuf : DXILOpClass;
-  def waveMultiPrefixBitCount : DXILOpClass;
-  def waveMultiPrefixOp : DXILOpClass;
-  def wavePrefixOp : DXILOpClass;
-  def waveReadLaneAt : DXILOpClass;
-  def waveReadLaneFirst : DXILOpClass;
-  def worldRayDirection : DXILOpClass;
-  def worldRayOrigin : DXILOpClass;
-  def worldToObject : DXILOpClass;
-  def writeSamplerFeedback : DXILOpClass;
-  def writeSamplerFeedbackBias : DXILOpClass;
-  def writeSamplerFeedbackGrad : DXILOpClass;
-  def writeSamplerFeedbackLevel: DXILOpClass;
+//  def acceptHitAndEndSearch : DXILOpClass;
+//  def allocateNodeOutputRecords : DXILOpClass;
+//  def allocateRayQuery : DXILOpClass;
+//  def annotateHandle : DXILOpClass;
+//  def annotateNodeHandle : DXILOpClass;
+//  def annotateNodeRecordHandle : DXILOpClass;
+//  def atomicBinOp : DXILOpClass;
+//  def atomicCompareExchange : DXILOpClass;
+//  def attributeAtVertex : DXILOpClass;
+//  def barrier : DXILOpClass;
+//  def barrierByMemoryHandle : DXILOpClass;
+//  def barrierByMemoryType : DXILOpClass;
+//  def barrierByNodeRecordHandle : DXILOpClass;
+def binary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
+//  def binaryWithCarryOrBorrow : DXILOpClass;
+//  def binaryWithTwoOuts : DXILOpClass;
+//  def bitcastF16toI16 : DXILOpClass;
+//  def bitcastF32toI32 : DXILOpClass;
+//  def bitcastF64toI64 : DXILOpClass;
+//  def bitcastI16toF16 : DXILOpClass;
+//  def bitcastI32toF32 : DXILOpClass;
+//  def bitcastI64toF64 : DXILOpClass;
+//  def bufferLoad : DXILOpClass;
+//  def bufferStore : DXILOpClass;
+//  def bufferUpdateCounter : DXILOpClass;
+//  def calculateLOD : DXILOpClass;
+//  def callShader : DXILOpClass;
+//  def cbufferLoad : DXILOpClass;
+//  def cbufferLoadLegacy : DXILOpClass;
+//  def checkAccessFullyMapped : DXILOpClass;
+//  def coverage : DXILOpClass;
+//  def createHandle : DXILOpClass;
+//  def createHandleForLib : DXILOpClass;
+//  def createHandleFromBinding : DXILOpClass;
+//  def createHandleFromHeap : DXILOpClass;
+//  def createNodeInputRecordHandle : DXILOpClass;
+//  def createNodeOutputHandle : DXILOpClass;
+//  def cutStream : DXILOpClass;
+//  def cycleCounterLegacy : DXILOpClass;
+//  def discard : DXILOpClass;
+//  def dispatchMesh : DXILOpClass;
+//  def dispatchRaysDimensions : DXILOpClass;
+//  def dispatchRaysIndex : DXILOpClass;
+//  def domainLocation : DXILOpClass;
+def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
+//  def dot2AddHalf : DXILOpClass;
+def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
+def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
+//  def dot4AddPacked : DXILOpClass;
+//  def emitIndices : DXILOpClass;
+//  def emitStream : DXILOpClass;
+//  def emitThenCutStream : DXILOpClass;
+//  def evalCentroid : DXILOpClass;
+//  def evalSampleIndex : DXILOpClass;
+//  def evalSnapped : DXILOpClass;
+//  def finishedCrossGroupSharing : DXILOpClass;
+def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
+//  def geometryIndex : DXILOpClass;
+//  def getDimensions : DXILOpClass;
+//  def getInputRecordCount : DXILOpClass;
+//  def getMeshPayload : DXILOpClass;
+//  def getNodeRecordPtr : DXILOpClass;
+//  def getRemainingRecursionLevels : DXILOpClass;
+def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+//  def gsInstanceID : DXILOpClass;
+//  def hitKind : DXILOpClass;
+//  def ignoreHit : DXILOpClass;
+//  def incrementOutputCount : DXILOpClass;
+//  def indexNodeHandle : DXILOpClass;
+//  def innerCoverage : DXILOpClass;
+//  def instanceID : DXILOpClass;
+//  def instanceIndex : DXILOpClass;
+//  def isHelperLane : DXILOpClass;
+def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
+//  def legacyDoubleToFloat : DXILOpClass;
+//  def legacyDoubleToSInt32 : DXILOpClass;
+//  def legacyDoubleToUInt32 : DXILOpClass;
+//  def legacyF16ToF32 : DXILOpClass;
+//  def legacyF32ToF16 : DXILOpClass;
+//  def loadInput : DXILOpClass;
+//  def loadOutputControlPoint : DXILOpClass;
+//  def loadPatchConstant : DXILOpClass;
+//  def makeDouble : DXILOpClass;
+//  def minPrecXRegLoad : DXILOpClass;
+//  def minPrecXRegStore : DXILOpClass;
+//  def nodeOutputIsValid : DXILOpClass;
+//  def objectRayDirection : DXILOpClass;
+//  def objectRayOrigin : DXILOpClass;
+//  def objectToWorld : DXILOpClass;
+//  def outputComplete : DXILOpClass;
+//  def outputControlPointID : DXILOpClass;
+//  def pack4x8 : DXILOpClass;
+//  def primitiveID : DXILOpClass;
+//  def primitiveIndex : DXILOpClass;
+//  def quadOp : DXILOpClass;
+//  def quadReadLaneAt : DXILOpClass;
+//  def quadVote : DXILOpClass;
+//  def quaternary : DXILOpClass;
+//  def rawBufferLoad : DXILOpClass;
+//  def rawBufferStore : DXILOpClass;
+//  def rayFlags : DXILOpClass;
+//  def rayQuery_Abort : DXILOpClass;
+//  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
+//  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
+//  def rayQuery_Proceed : DXILOpClass;
+//  def rayQuery_StateMatrix : DXILOpClass;
+//  def rayQuery_StateScalar : DXILOpClass;
+//  def rayQuery_StateVector : DXILOpClass;
+//  def rayQuery_TraceRayInline : DXILOpClass;
+//  def rayTCurrent : DXILOpClass;
+//  def rayTMin : DXILOpClass;
+//  def renderTargetGetSampleCount : DXILOpClass;
+//  def renderTargetGetSamplePosition : DXILOpClass;
+//  def reportHit : DXILOpClass;
+//  def sample : DXILOpClass;
+//  def sampleBias : DXILOpClass;
+//  def sampleCmp : DXILOpClass;
+//  def sampleCmpBias : DXILOpClass;
+//  def sampleCmpGrad : DXILOpClass;
+//  def sampleCmpLevel : DXILOpClass;
+//  def sampleCmpLevelZero : DXILOpClass;
+//  def sampleGrad : DXILOpClass;
+//  def sampleIndex : DXILOpClass;
+//  def sampleLevel : DXILOpClass;
+//  def setMeshOutputCounts : DXILOpClass;
+//  def splitDouble : DXILOpClass;
+//  def startInstanceLocation : DXILOpClass;
+//  def startVertexLocation : DXILOpClass;
+//  def storeOutput : DXILOpClass;
+//  def storePatchConstant : DXILOpClass;
+//  def storePrimitiveOutput : DXILOpClass;
+//  def storeVertexOutput : DXILOpClass;
+//  def tempRegLoad : DXILOpClass;
+//  def tempRegStore : DXILOpClass;
+def tertiary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
+//  def texture2DMSGetSamplePosition : DXILOpClass;
+//  def textureGather : DXILOpClass;
+//  def textureGatherCmp : DXILOpClass;
+//  def textureGatherRaw : DXILOpClass;
+//  def textureLoad : DXILOpClass;
+//  def textureStore : DXILOpClass;
+//  def textureStoreSample : DXILOpClass;
+def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+//  def traceRay : DXILOpClass;
+def unary : DXILOpClass<[llvm_any_ty, llvm_any_ty]>;
+//  def unaryBits : DXILOpClass;
+//  def unpack4x8 : DXILOpClass;
+//  def viewID : DXILOpClass;
+//  def waveActiveAllEqual : DXILOpClass;
+//  def waveActiveBallot : DXILOpClass;
+//  def waveActiveBit : DXILOpClass;
+//  def waveActiveOp : DXILOpClass;
+//  def waveAllOp : DXILOpClass;
+//  def waveAllTrue : DXILOpClass;
+//  def waveAnyTrue : DXILOpClass;
+//  def waveGetLaneCount : DXILOpClass;
+//  def waveGetLaneIndex : DXILOpClass;
+//  def waveIsFirstLane : DXILOpClass;
+//  def waveMatch : DXILOpClass;
+//  def waveMatrix_Accumulate : DXILOpClass;
+//  def waveMatrix_Annotate : DXILOpClass;
+//  def waveMatrix_Depth : DXILOpClass;
+//  def waveMatrix_Fill : DXILOpClass;
+//  def waveMatrix_LoadGroupShared : DXILOpClass;
+//  def waveMatrix_LoadRawBuf : DXILOpClass;
+//  def waveMatrix_Multiply : DXILOpClass;
+//  def waveMatrix_ScalarOp : DXILOpClass;
+//  def waveMatrix_StoreGroupShared : DXILOpClass;
+//  def waveMatrix_StoreRawBuf : DXILOpClass;
+//  def waveMultiPrefixBitCount : DXILOpClass;
+//  def waveMultiPrefixOp : DXILOpClass;
+//  def wavePrefixOp : DXILOpClass;
+//  def waveReadLaneAt : DXILOpClass;
+//  def waveReadLaneFirst : DXILOpClass;
+//  def worldRayDirection : DXILOpClass;
+//  def worldRayOrigin : DXILOpClass;
+//  def worldToObject : DXILOpClass;
+//  def writeSamplerFeedback : DXILOpClass;
+//  def writeSamplerFeedbackBias : DXILOpClass;
+//  def writeSamplerFeedbackGrad : DXILOpClass;
+//  def writeSamplerFeedbackLevel: DXILOpClass;
 
   // This is a sentinel definition. Hence placed at the end of the list
   // and not as part of the above alphabetically sorted valid definitions.
   // Additionally it is capitalized unlike all the others.
-  def UnknownOpClass: DXILOpClass;
-}
-
-// Several of the overloaded DXIL Operations support for data types
-// that are a subset of the overloaded LLVM intrinsics that they map to.
-// For e.g., llvm.sin.* intrinsic operates on any floating-point type and
-// maps for lowering to DXIL Op Sin. However, valid overloads of DXIL Sin
-// operation overloads are half (f16) and float (f32) only.
-//
-// The following abstracts overload types specific to DXIL operations.
-
-class DXILType : LLVMType<OtherVT> {
-  let isAny = 1;
-  int isI16OrI32 = 0;
-  int isHalfOrFloat = 0;
+def UnknownOpClass: DXILOpClass<[]>;
 }
 
-// Concrete records for various overload types supported specifically by
-// DXIL Operations.
-let isI16OrI32 = 1 in
-  def llvm_i16ori32_ty : DXILType;
-
-let isHalfOrFloat = 1 in
-  def llvm_halforfloat_ty : DXILType;
-
 // Abstraction DXIL Operation to LLVM intrinsic
 class DXILOpMappingBase {
   int OpCode = 0;                      // Opcode of DXIL Operation
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
+  list<LLVMType> OpOverloadTypes = ?;      // Fixed types valid for overload type
+                                       // of DXIL Operation
   string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
 }
 
-class DXILOpMapping<int opCode, DXILOpClass opClass,
-                    Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
-  int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
-  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
-  string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
+class DXILOpMapping<int opCode,
+                    Intrinsic intrinsic,
+                    string doc> : DXILOpMappingBase {
+  int OpCode = opCode;
+  Intrinsic LLVMIntrinsic = intrinsic;
+  string Doc = doc;
+}
+
+// Concrete definitions of DXIL Operation mapping to corresponding LLVM intrinsic
+
+// IsSpecialFloat Class
+let OpClass = isSpecialFloat, OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in
+  def IsInf : DXILOpMapping<9,  int_dx_isinf,
+                           "Determines if the specified value is infinite.">;
+
+// Unary Class
+let OpClass = unary in {
+  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in
+    def Abs : DXILOpMapping<6, int_fabs,
+                            "Returns the absolute value of the input.">;
+
+  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in {
+    def Cos  : DXILOpMapping<12, int_cos,
+                             "Returns cosine(theta) for theta in radians.">;
+    def Sin  : DXILOpMapping<13, int_sin,
+                             "Returns sine(theta) for theta in radians.">;
+    def Exp2 : DXILOpMapping<21, int_exp2,
+                             "Returns the base 2 exponential, or 2**x, of the"
+                             " specified value. exp2(x) = 2**x.">;
+    def Frac : DXILOpMapping<22, int_dx_frac,
+                             "Returns a fraction from 0 to 1 that represents the"
+                             " decimal part of the input.">;
+    def Log2 : DXILOpMapping<23, int_log2,
+                             "Returns the base-2 logarithm of the specified value.">;
+    def Sqrt : DXILOpMapping<24, int_sqrt,
+                             "Returns the square root of the specified floating-point"
+                             " value, per component.">;
+    def RSqrt : DXILOpMapping<25, int_dx_rsqrt,
+                             "Returns the reciprocal of the square root of the"
+                             " specified value. rsqrt(x) = 1 / sqrt(x).">;
+    def Round : DXILOpMapping<26, int_roundeven,
+                              "Returns the input rounded to the nearest integer"
+                              " within a floating-point type.">;
+    def Floor : DXILOpMapping<27, int_floor,
+                              "Returns the largest integer that is less than or equal to the input.">;
+    def Ceil  : DXILOpMapping<28, int_ceil,
+                              "Returns the smallest integer that is greater than or equal to the input.">;
+    def Trunc : DXILOpMapping<29, int_trunc,
+                              "Returns the specified value truncated to the integer component.">;
+  }
+  let OpOverloadTypes = [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
+    def Rbits : DXILOpMapping<30, int_bitreverse,
+                              "Returns the specified value with its bits reversed.">;
+  }
+}
+
+// Binary Class
+let OpClass = binary in {
+  // Float overloads
+  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
+    def FMax : DXILOpMapping<35, int_maxnum,
+                             "Float maximum. FMax(a,b) = a > b ? a : b">;
+    def FMin : DXILOpMapping<36, int_minnum,
+                             "Float minimum. FMin(a,b) = a < b ? a : b">;
+  }
+  // Int overloads
+  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
+    def SMax : DXILOpMapping<37, int_smax,
+                             "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
+    def SMin : DXILOpMapping<38, int_smin,
+                             "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
+    def UMax : DXILOpMapping<39, int_umax,
+                             "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+    def UMin : DXILOpMapping<40, int_umin,
+                             "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
+  }
 }
 
-// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
-def Abs : DXILOpMapping<6, unary, int_fabs,
-                         "Returns the absolute value of the input.">;
-def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
-                         "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
-def Cos  : DXILOpMapping<12, unary, int_cos,
-                         "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Exp2 : DXILOpMapping<21, unary, int_exp2,
-                         "Returns the base 2 exponential, or 2**x, of the specified value."
-                         "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Frac : DXILOpMapping<22, unary, int_dx_frac,
-                         "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Log2 : DXILOpMapping<23, unary, int_log2,
-                         "Returns the base-2 logarithm of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sqrt : DXILOpMapping<24, unary, int_sqrt,
-                         "Returns the square root of the specified floating-point"
-                         "value, per component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
-                         "Returns the reciprocal of the square root of the specified value."
-                         "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Round : DXILOpMapping<26, unary, int_roundeven,
-                         "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Floor : DXILOpMapping<27, unary, int_floor,
-                         "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Ceil : DXILOpMapping<28, unary, int_ceil,
-                         "Returns the smallest integer that is greater than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Trunc : DXILOpMapping<29, unary, int_trunc,
-                         "Returns the specified value truncated to the integer component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Rbits : DXILOpMapping<30, unary, int_bitreverse,
-                         "Returns the specified value with its bits reversed.",
-                         [llvm_anyint_ty, LLVMMatchType<0>]>;
-def FMax : DXILOpMapping<35, binary, int_maxnum,
-                         "Float maximum. FMax(a,b) = a > b ? a : b">;
-def FMin : DXILOpMapping<36, binary, int_minnum,
-                         "Float minimum. FMin(a,b) = a < b ? a : b">;
-def SMax : DXILOpMapping<37, binary, int_smax,
-                         "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
-def SMin : DXILOpMapping<38, binary, int_smin,
-                         "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
-def UMax : DXILOpMapping<39, binary, int_umax,
-                         "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
-def UMin : DXILOpMapping<40, binary, int_umin,
-                         "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
-def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
-                         "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
-def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
-                         "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
-def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
-                         "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
-  def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
-  def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
-  def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
-def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
-                             "Reads the thread ID">;
-def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
-                             "Reads the group ID (SV_GroupID)">;
-def ThreadIdInGroup : DXILOpMapping<95, threadIdInGroup,
-                                    int_dx_thread_id_in_group,
+// Tertiary Class
+let OpClass = tertiary in {
+  // Float overloads
+  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
+    def FMad : DXILOpMapping<46, int_fmuladd,
+                             "Floating point arithmetic multiply/add operation."
+                             " fmad(m,a,b) = m * a + b.">;
+  }
+  // Int overloads
+  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
+    def IMad : DXILOpMapping<48, int_dx_imad,
+                             "Signed integer arithmetic multiply/add operation."
+                             " imad(m,a,b) = m * a + b.">;
+    def UMad : DXILOpMapping<49, int_dx_umad,
+                             "Unsigned integer arithmetic multiply/add operation."
+                             " umad(m,a,b) = m * a + b.">;
+  }
+}
+
+// Dot Operations
+let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty] in {
+  let OpClass = dot2 in
+    def Dot2 : DXILOpMapping<54, int_dx_dot2,
+                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                             " ... + a[n]*b[n] where n is between 0 and 1">;
+  let OpClass = dot3 in
+    def Dot3 : DXILOpMapping<55, int_dx_dot3,
+                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                             " ... + a[n]*b[n] where n is between 0 and 2">;
+  let OpClass = dot4 in
+    def Dot4 : DXILOpMapping<56, int_dx_dot4,
+                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                             " ... + a[n]*b[n] where n is between 0 and 3">;
+}
+
+// Thread Operations
+let OpOverloadTypes = [llvm_i32_ty] in {
+  let OpClass =  threadId in
+    def ThreadId : DXILOpMapping<93, int_dx_thread_id,
+                                 "Reads the thread ID">;
+  let OpClass =  groupId in
+    def GroupId  : DXILOpMapping<94, int_dx_group_id,
+                                 "Reads the group ID (SV_GroupID)">;
+  let OpClass =  threadIdInGroup in
+    def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
                                     "Reads the thread ID within the group "
                                     "(SV_GroupThreadID)">;
-def FlattenedThreadIdInGroup : DXILOpMapping<96, flattenedThreadIdInGroup,
-                                             int_dx_flattened_thread_id_in_group,
-                                             "Provides a flattened index for a "
-                                             "given thread within a given "
-                                             "group (SV_GroupIndex)">;
+  let OpClass = flattenedThreadIdInGroup in
+    def FlattenedThreadIdInGroup : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
+                                                 "Provides a flattened index for a given"
+                                                 " thread within a given group (SV_GroupIndex)">;
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index f2504775d557f2..77f0d2cb8c937f 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
 #include "llvm/Support/DXILABI.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 #include <string>
@@ -41,6 +42,8 @@ struct DXILOperationDesc {
   StringRef Doc;      // the documentation description of this instruction
   SmallVector<Record *> OpTypes; // Vector of operand type records -
                                  // return type is at index 0
+  SmallVector<Record *> OpOverloads; // Vector of fixed types valid for
+                                     // operation overloads
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -91,15 +94,12 @@ static ParameterKind getParameterKind(const Record *R) {
     return ParameterKind::I32;
   case MVT::fAny:
   case MVT::iAny:
+  case MVT::Any:
     return ParameterKind::Overload;
-  case MVT::Other:
-    // Handle DXIL-specific overload types
-    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
-      return ParameterKind::Overload;
-    }
-    LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable("Support for specified DXIL Type not yet implemented");
+    report_fatal_error(
+        "Support for specified parameter type not yet implemented",
+        /*gen_crash_diag*/ false);
   }
 }
 
@@ -113,9 +113,9 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   OpCode = R->getValueAsInt("OpCode");
 
   Doc = R->getValueAsString("Doc");
-
-  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
-  unsigned TypeRecsSize = TypeRecs.size();
+  Record *OpClassRec = R->getValueAsDef("OpClass");
+  auto ParamTypeRecs = OpClassRec->getValueAsListOfDefs("OpSignature");
+  unsigned ParamTypeRecsSize = ParamTypeRecs.size();
   // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
@@ -124,32 +124,32 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // the comment before the definition of class LLVMMatchType in
   // llvm/IR/Intrinsics.td
   SmallVector<int> OverloadParamIndices;
-  for (unsigned i = 0; i < TypeRecsSize; i++) {
-    auto TR = TypeRecs[i];
+  for (unsigned I = 0; I < ParamTypeRecsSize; I++) {
+    auto TR = ParamTypeRecs[I];
     // Track operation parameter indices of any overload types
-    auto isAny = TR->getValueAsInt("isAny");
-    if (isAny == 1) {
+    auto IsAny = TR->getValueAsInt("isAny");
+    if (IsAny == 1) {
       // TODO: At present it is expected that all overload types in a DXIL Op
       // are of the same type. Hence, OverloadParamIndices will have only one
       // element. This implies we do not need a vector. However, until more
       // (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
       // cases this assumption would not hold.
       if (!OverloadParamIndices.empty()) {
-        bool knownType = true;
+        bool KnownType = true;
         // Ensure that the same overload type registered earlier is being used
         for (auto Idx : OverloadParamIndices) {
-          if (TR != TypeRecs[Idx]) {
-            knownType = false;
+          if (TR != ParamTypeRecs[Idx]) {
+            KnownType = false;
             break;
           }
         }
-        if (!knownType) {
+        if (!KnownType) {
           report_fatal_error("Specification of multiple differing overload "
                              "parameter types not yet supported",
-                             false);
+                             /*gen_crash_diag*/ false);
         }
       } else {
-        OverloadParamIndices.push_back(i);
+        OverloadParamIndices.push_back(I);
       }
     }
     // Populate OpTypes array according to the type specification
@@ -160,7 +160,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
       // Get the parameter index of anonymous type, TR, references
       auto OLParamIndex = TR->getValueAsInt("Number");
       // Resolve and insert the type to that at OLParamIndex
-      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+      OpTypes.emplace_back(ParamTypeRecs[OLParamIndex]);
     } else {
       // A non-anonymous type. Just record it in OpTypes
       OpTypes.emplace_back(TR);
@@ -172,9 +172,18 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   if (!OverloadParamIndices.empty()) {
     if (OverloadParamIndices.size() > 1)
       report_fatal_error("Multiple overload type specification not supported",
-                         false);
+                         /*gen_crash_diag*/ false);
     OverloadParamIndex = OverloadParamIndices[0];
   }
+
+  // Get valid overload types of the Operation
+  auto OverloadTypeRecs = R->getValueAsListOfDefs("OpOverloadTypes");
+  unsigned OverloadTypeRecsSize = OverloadTypeRecs.size();
+  // Populate OpOverloads with the overload type records
+  for (unsigned I = 0; I < OverloadTypeRecsSize; I++) {
+    OpOverloads.emplace_back(OverloadTypeRecs[I]);
+  }
+
   // Get the operation class
   OpClass = R->getValueAsDef("OpClass")->getName();
 
@@ -188,10 +197,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
     // that of the intrinsic. Deviations are expected to be encoded in TableGen
     // record specification and handled accordingly here. Support to be added
     // as needed.
-    auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
+    ListInit *IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
     auto IntrPropListSize = IntrPropList->size();
-    for (unsigned i = 0; i < IntrPropListSize; i++) {
-      OpAttributes.emplace_back(IntrPropList->getElement(i)->getAsString());
+    for (unsigned I = 0; I < IntrPropListSize; I++) {
+      OpAttributes.emplace_back(IntrPropList->getElement(I)->getAsString());
     }
   }
 }
@@ -233,16 +242,9 @@ static std::string getParameterKindStr(ParameterKind Kind) {
   llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
 }
 
-/// Return a string representation of OverloadKind enum that maps to
-/// input LLVMType record
-/// \param R TableGen def record of class LLVMType
-/// \return std::string string representation of OverloadKind
-
 static std::string getOverloadKindStr(const Record *R) {
-  auto VTRec = R->getValueAsDef("VT");
+  Record *VTRec = R->getValueAsDef("VT");
   switch (getValueType(VTRec)) {
-  case MVT::isVoid:
-    return "OverloadKind::VOID";
   case MVT::f16:
     return "OverloadKind::HALF";
   case MVT::f32:
@@ -259,24 +261,28 @@ static std::string getOverloadKindStr(const Record *R) {
     return "OverloadKind::I32";
   case MVT::i64:
     return "OverloadKind::I64";
-  case MVT::iAny:
-    return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
-  case MVT::fAny:
-    return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
-  case MVT::Other:
-    // Handle DXIL-specific overload types
-    {
-      if (R->getValueAsInt("isHalfOrFloat")) {
-        return "OverloadKind::HALF | OverloadKind::FLOAT";
-      } else if (R->getValueAsInt("isI16OrI32")) {
-        return "OverloadKind::I16 | OverloadKind::I32";
-      }
-    }
-    LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable(
-        "Support for specified parameter OverloadKind not yet implemented");
+    llvm_unreachable("Support for specified fixed type option for overload "
+                     "type not supported");
+  }
+}
+
+/// Return the OverloadKind enum values that map to the LLVMType records in
+/// \p OverloadTys, joined with " | " (an empty string if there are none).
+/// \param OverloadTys TableGen def records of class LLVMType
+/// \return std::string string representation of the OverloadKind mask
+static std::string
+getOverloadKindStrs(ArrayRef<Record *> OverloadTys) {
+  if (OverloadTys.empty()) {
+    return {};
+  }
+  std::string OverloadString = "";
+  auto Iter = OverloadTys.begin();
+  OverloadString.append(getOverloadKindStr(*Iter++));
+  for (; Iter != OverloadTys.end(); ++Iter) {
+    OverloadString.append(" | ").append(getOverloadKindStr(*Iter));
   }
+  return OverloadString;
 }
 
 /// Emit Enums of DXIL Ops
@@ -411,7 +417,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
-       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
+       << getOverloadKindStrs(Op.OpOverloads) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
@@ -477,7 +483,7 @@ static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
   OS << "\n";
   // Get all DXIL Ops to intrinsic mapping records
   std::vector<Record *> OpIntrMaps =
-      Records.getAllDerivedDefinitions("DXILOpMapping");
+      Records.getAllDerivedDefinitions("DXILOpMappingBase");
   std::vector<DXILOperationDesc> DXILOps;
   for (auto *Record : OpIntrMaps) {
     DXILOps.emplace_back(DXILOperationDesc(Record));

>From 6dd3b750fbf36275c4aacc198b3c13605ad35367 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 29 Mar 2024 18:33:05 -0400
Subject: [PATCH 2/4] Add support to specify overload types specific to Shader
 Model for TableGen records of DXIL Operations.

---
 llvm/lib/Target/DirectX/DXIL.td            | 240 +++++++++++----------
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp  |  48 ++++-
 llvm/lib/Target/DirectX/DXILOpBuilder.h    |  15 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |  23 +-
 llvm/test/CodeGen/DirectX/abs.ll           |   2 +-
 llvm/test/CodeGen/DirectX/ceil.ll          |   3 +-
 llvm/test/CodeGen/DirectX/ceil_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/clamp.ll         |   2 +-
 llvm/test/CodeGen/DirectX/cos.ll           |   2 +-
 llvm/test/CodeGen/DirectX/cos_error.ll     |   2 +-
 llvm/test/CodeGen/DirectX/dot2_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/dot3_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/dot4_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/exp.ll           |   2 +-
 llvm/test/CodeGen/DirectX/exp2_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/fabs.ll          |   3 +-
 llvm/test/CodeGen/DirectX/fdot.ll          |   2 +-
 llvm/test/CodeGen/DirectX/floor.ll         |   2 +-
 llvm/test/CodeGen/DirectX/floor_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/fmax.ll          |   2 +-
 llvm/test/CodeGen/DirectX/fmin.ll          |   2 +-
 llvm/test/CodeGen/DirectX/frac_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/idot.ll          |   2 +-
 llvm/test/CodeGen/DirectX/isinf.ll         |   2 +-
 llvm/test/CodeGen/DirectX/isinf_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/log.ll           |   2 +-
 llvm/test/CodeGen/DirectX/log10.ll         |   2 +-
 llvm/test/CodeGen/DirectX/log2.ll          |   2 +-
 llvm/test/CodeGen/DirectX/log2_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/pow.ll           |   2 +-
 llvm/test/CodeGen/DirectX/reversebits.ll   |   2 +-
 llvm/test/CodeGen/DirectX/round.ll         |   2 +-
 llvm/test/CodeGen/DirectX/round_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/rsqrt.ll         |   2 +-
 llvm/test/CodeGen/DirectX/rsqrt_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/sin.ll           |   6 +-
 llvm/test/CodeGen/DirectX/sin_error.ll     |   2 +-
 llvm/test/CodeGen/DirectX/sin_sm_error.ll  |  15 ++
 llvm/test/CodeGen/DirectX/smax.ll          |   2 +-
 llvm/test/CodeGen/DirectX/smin.ll          |   2 +-
 llvm/test/CodeGen/DirectX/sqrt.ll          |   2 +-
 llvm/test/CodeGen/DirectX/sqrt_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/trunc.ll         |   2 +-
 llvm/test/CodeGen/DirectX/trunc_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/umax.ll          |   2 +-
 llvm/test/CodeGen/DirectX/umin.ll          |   2 +-
 llvm/utils/TableGen/DXILEmitter.cpp        |  52 +++--
 47 files changed, 298 insertions(+), 183 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/sin_sm_error.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 082e7596aded35..0ea0f016fd6c74 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,9 +13,33 @@
 
 include "llvm/IR/Intrinsics.td"
 
-// Absttraction of DXIL Operation class.
+// Abstract class to denote the minimum Shader Model version required
+// to support a DXIL Op
+class DXILShaderModel<int major, int minor> {
+  int MajorAndMinor = !add(!mul(major, 10), minor);
+}
+
+// Valid minimum Shader model version records
+
+// Shader Model 6.x
+foreach i = 0...9 in {
+  def SM6_#i : DXILShaderModel<6, i>;
+}
+// Shader Model 7.x - for now only 7.0 is defined. Extend as needed.
+foreach i = 0 in {
+  def SM7_#i : DXILShaderModel<7, i>;
+}
+
+// Abstraction of a class mapping a set of valid DXIL Op overload types to
+// the minimum Shader Model version in which they are supported
+class DXILOpOverload<DXILShaderModel minsm, list<LLVMType> overloads> {
+  DXILShaderModel ShaderModel = minsm;
+  list<LLVMType> OpOverloads = overloads;
+}
+
+// Abstraction of DXIL Operation class.
 // It encapsulates an associated function signature viz.,
-// returnty(param1, param2, ...) represented as a list of LLVMTypes.
+// returnTy(param1Ty, param2Ty, ...) represented as a list of LLVMTypes.
 // DXIL Ops that belong to a DXILOpClass record the signature of that
 // DXILOpClass
 
@@ -30,21 +54,21 @@ class DXILOpClass<list<LLVMType> OpSig> {
 // name has the format "dx.op.<class-name>.<return-type>", in most cases.
 
 defset list<DXILOpClass> OpClasses = {
-//  def acceptHitAndEndSearch : DXILOpClass;
+def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
 //  def allocateNodeOutputRecords : DXILOpClass;
-//  def allocateRayQuery : DXILOpClass;
+def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
 //  def annotateHandle : DXILOpClass;
 //  def annotateNodeHandle : DXILOpClass;
 //  def annotateNodeRecordHandle : DXILOpClass;
 //  def atomicBinOp : DXILOpClass;
 //  def atomicCompareExchange : DXILOpClass;
-//  def attributeAtVertex : DXILOpClass;
-//  def barrier : DXILOpClass;
+def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
+def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
 //  def barrierByMemoryHandle : DXILOpClass;
-//  def barrierByMemoryType : DXILOpClass;
+def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
 //  def barrierByNodeRecordHandle : DXILOpClass;
-def binary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
-//  def binaryWithCarryOrBorrow : DXILOpClass;
+def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
 //  def binaryWithTwoOuts : DXILOpClass;
 //  def bitcastF16toI16 : DXILOpClass;
 //  def bitcastF32toI32 : DXILOpClass;
@@ -164,7 +188,7 @@ def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
 //  def storeVertexOutput : DXILOpClass;
 //  def tempRegLoad : DXILOpClass;
 //  def tempRegStore : DXILOpClass;
-def tertiary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
+def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 //  def texture2DMSGetSamplePosition : DXILOpClass;
 //  def textureGather : DXILOpClass;
 //  def textureGatherCmp : DXILOpClass;
@@ -175,7 +199,7 @@ def tertiary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty, llvm_any_ty]>
 def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
 def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
 //  def traceRay : DXILOpClass;
-def unary : DXILOpClass<[llvm_any_ty, llvm_any_ty]>;
+def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
 //  def unaryBits : DXILOpClass;
 //  def unpack4x8 : DXILOpClass;
 //  def viewID : DXILOpClass;
@@ -224,138 +248,128 @@ class DXILOpMappingBase {
   int OpCode = 0;                      // Opcode of DXIL Operation
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
-  list<LLVMType> OpOverloadTypes = ?;      // Fixed types valid for overload type
+  list<DXILOpOverload> OpOverloadTypes = ?; // Valid overload type
                                        // of DXIL Operation
   string Doc = "";                     // A short description of the operation
 }
 
 class DXILOpMapping<int opCode,
                     Intrinsic intrinsic,
+                    list<DXILOpOverload> overloadTypes,
                     string doc> : DXILOpMappingBase {
   int OpCode = opCode;
   Intrinsic LLVMIntrinsic = intrinsic;
+  list<DXILOpOverload> OpOverloadTypes = overloadTypes;
   string Doc = doc;
 }
 
 // Concrete definitions of DXIL Operation mapping to corresponding LLVM intrinsic
 
 // IsSpecialFloat Class
-let OpClass = isSpecialFloat, OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in
-  def IsInf : DXILOpMapping<9,  int_dx_isinf,
+let OpClass = isSpecialFloat in {
+  def IsInf : DXILOpMapping<9,  int_dx_isinf, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
                            "Determines if the specified value is infinite.">;
+}
 
 // Unary Class
 let OpClass = unary in {
-  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in
-    def Abs : DXILOpMapping<6, int_fabs,
-                            "Returns the absolute value of the input.">;
-
-  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in {
-    def Cos  : DXILOpMapping<12, int_cos,
-                             "Returns cosine(theta) for theta in radians.">;
-    def Sin  : DXILOpMapping<13, int_sin,
-                             "Returns sine(theta) for theta in radians.">;
-    def Exp2 : DXILOpMapping<21, int_exp2,
-                             "Returns the base 2 exponential, or 2**x, of the"
-                             " specified value. exp2(x) = 2**x.">;
-    def Frac : DXILOpMapping<22, int_dx_frac,
-                             "Returns a fraction from 0 to 1 that represents the"
-                             " decimal part of the input.">;
-    def Log2 : DXILOpMapping<23, int_log2,
-                             "Returns the base-2 logarithm of the specified value.">;
-    def Sqrt : DXILOpMapping<24, int_sqrt,
-                             "Returns the square root of the specified floating-point"
-                             "value, per component.">;
-    def RSqrt : DXILOpMapping<25, int_dx_rsqrt,
-                             "Returns the reciprocal of the square root of the"
-                             " specified value. rsqrt(x) = 1 / sqrt(x).">;
-    def Round : DXILOpMapping<26, int_round,
-                              "Returns the input rounded to the nearest integer"
-                              "within a floating-point type.">;
-    def Floor : DXILOpMapping<27, int_floor,
-                              "Returns the largest integer that is less than or equal to the input.">;
-    def Trunc : DXILOpMapping<29, int_trunc,
-                         "Returns the specified value truncated to the integer component.">;
-  }
-  let OpOverloadTypes = [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
-    def Rbits : DXILOpMapping<30, int_bitreverse,
-                         "Returns the specified value with its bits reversed.">;
+  def Abs : DXILOpMapping<6, int_fabs, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                          "Returns the absolute value of the input.">;
 
-  }
+  def Cos  : DXILOpMapping<12, int_cos, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                          "Returns cosine(theta) for theta in radians.">;
+  def Sin  : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_0, [llvm_float_ty]>,
+                                         DXILOpOverload<SM6_3, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns sine(theta) for theta in radians.">;
+  def Exp2 : DXILOpMapping<21, int_exp2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the base 2 exponential, or 2**x, of the"
+                           " specified value. exp2(x) = 2**x.">;
+  def Frac : DXILOpMapping<22, int_dx_frac, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns a fraction from 0 to 1 that represents the"
+                            " decimal part of the input.">;
+  def Log2 : DXILOpMapping<23, int_log2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the base-2 logarithm of the specified value.">;
+  def Sqrt : DXILOpMapping<24, int_sqrt, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the square root of the specified floating-point"
+                           " value, per component.">;
+  def RSqrt : DXILOpMapping<25, int_dx_rsqrt, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the reciprocal of the square root of the"
+                            " specified value. rsqrt(x) = 1 / sqrt(x).">;
+  def Round : DXILOpMapping<26, int_roundeven, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the input rounded to the nearest integer"
+                            " within a floating-point type.">;
+  def Floor : DXILOpMapping<27, int_floor, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the largest integer that is less than or equal to the input.">;
+  def Ceil  : DXILOpMapping<28, int_ceil, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the smallest integer that is greater than or equal to the input.">;
+  def Trunc : DXILOpMapping<29, int_trunc, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the specified value truncated to the integer component.">;
+  def Rbits : DXILOpMapping<30, int_bitreverse, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                            "Returns the specified value with its bits reversed.">;
 }
 
-
 // Binary Class
 let OpClass = binary in {
-  // Float overloads
-  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
-    def FMax : DXILOpMapping<35, int_maxnum,
-                             "Float maximum. FMax(a,b) = a > b ? a : b">;
-    def FMin : DXILOpMapping<36, int_minnum,
-                             "Float minimum. FMin(a,b) = a < b ? a : b">;
-  }
-  // Int overloads
-  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
-    def SMax : DXILOpMapping<37, int_smax,
-                             "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
-    def SMin : DXILOpMapping<38, int_smin,
-                             "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
-    def UMax : DXILOpMapping<39, int_umax,
-                             "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
-    def UMin : DXILOpMapping<40, int_umin,
-                             "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
-  }
+// Float overloads
+  def FMax : DXILOpMapping<35, int_maxnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Float maximum. FMax(a,b) = a > b ? a : b">;
+  def FMin : DXILOpMapping<36, int_minnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Float minimum. FMin(a,b) = a < b ? a : b">;
+// Int overloads
+  def SMax : DXILOpMapping<37, int_smax, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
+  def SMin : DXILOpMapping<38, int_smin, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
+  def UMax : DXILOpMapping<39, int_umax, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+  def UMin : DXILOpMapping<40, int_umin, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
 }
 
 // Tertiary Class
 let OpClass = tertiary in {
-  // Float overloads
-  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
-    def FMad : DXILOpMapping<46, int_fmuladd,
-                             "Floating point arithmetic multiply/add operation."
-                             " fmad(m,a,b) = m * a + b.">;
-  }
-  // Int overloads
-  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
-    def IMad : DXILOpMapping<48, int_dx_imad,
-                             "Signed integer arithmetic multiply/add operation."
-                             " imad(m,a,b) = m * a + b.">;
-    def UMad : DXILOpMapping<49, int_dx_umad,
-                         "Unsigned integer arithmetic multiply/add operation."
-                         " umad(m,a, = m * a + b.">;
-  }
+  // Float overloads
+  def FMad : DXILOpMapping<46, int_fmuladd, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Floating point arithmetic multiply/add operation."
+                           " fmad(m,a,b) = m * a + b.">;
+
+  // Int overloads
+  def IMad : DXILOpMapping<48, int_dx_imad, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer arithmetic multiply/add operation."
+                           " imad(m,a,b) = m * a + b.">;
+  def UMad : DXILOpMapping<49, int_dx_umad, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer arithmetic multiply/add operation."
+                           " umad(m,a,b) = m * a + b.">;
 }
 
 // Dot Operations
-let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty] in {
-  let OpClass = dot2 in
-    def Dot2 : DXILOpMapping<54, int_dx_dot2,
-                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
-                             " ... + a[n]*b[n] where n is between 0 and 1">;
-  let OpClass = dot3 in
-    def Dot3 : DXILOpMapping<55, int_dx_dot3,
-                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
-                             " ... + a[n]*b[n] where n is between 0 and 2">;
-  let OpClass = dot4 in
-    def Dot4 : DXILOpMapping<56, int_dx_dot4,
-                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
-                             " ... + a[n]*b[n] where n is between 0 and 3">;
-}
+// Each Dot op has its own class; half and float overloads are valid from SM6_0.
+let OpClass = dot2 in
+  def Dot2 : DXILOpMapping<54, int_dx_dot2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                           " ... + a[n]*b[n] where n is between 0 and 1">;
+let OpClass = dot3 in
+  def Dot3 : DXILOpMapping<55, int_dx_dot3, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                           " ... + a[n]*b[n] where n is between 0 and 2">;
+let OpClass = dot4 in
+  def Dot4 : DXILOpMapping<56, int_dx_dot4, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                           " ... + a[n]*b[n] where n is between 0 and 3">;
 
 // Thread Operations
-let OpOverloadTypes = [llvm_i32_ty] in {
-  let OpClass =  threadId in
-    def ThreadId : DXILOpMapping<93, int_dx_thread_id,
-                                 "Reads the thread ID">;
-  let OpClass =  groupId in
-    def GroupId  : DXILOpMapping<94, int_dx_group_id,
-                                 "Reads the group ID (SV_GroupID)">;
-  let OpClass =  threadIdInGroup in
-    def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
-                                    "Reads the thread ID within the group "
-                                    "(SV_GroupThreadID)">;
-  let OpClass = flattenedThreadIdInGroup in
-    def FlattenedThreadIdInGroup : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
-                                                 "Provides a flattened index for a given"
-                                                 " thread within a given group (SV_GroupIndex)">;
-}
+let OpClass =  threadId in
+  def ThreadId : DXILOpMapping<93, int_dx_thread_id, [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                              "Reads the thread ID">;
+let OpClass =  groupId in
+  def GroupId  : DXILOpMapping<94, int_dx_group_id, [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                               "Reads the group ID (SV_GroupID)">;
+let OpClass =  threadIdInGroup in
+  def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group, [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                                     "Reads the thread ID within the group "
+                                     "(SV_GroupThreadID)">;
+let OpClass = flattenedThreadIdInGroup in
+  def FlattenedThreadIdInGroup : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
+                                               [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                                                "Provides a flattened index for a given"
+                                                " thread within a given group (SV_GroupIndex)">;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 0b3982ea0f438a..1a4f8c709c0fd2 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <string>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -123,6 +124,11 @@ static std::string getTypeName(OverloadKind Kind, Type *Ty) {
   }
 }
 
+struct OpSMOverloadProp {
+  uint16_t ShaderModelVer;
+  uint16_t ValidTys;
+};
+
 // Static properties.
 struct OpCodeProperty {
   dxil::OpCode OpCode;
@@ -131,7 +137,7 @@ struct OpCodeProperty {
   dxil::OpCodeClass OpCodeClass;
   // Offset in DXILOpCodeClassNameTable.
   unsigned OpCodeClassNameOffset;
-  uint16_t OverloadTys;
+  std::vector<OpSMOverloadProp> OverloadProp;
   llvm::Attribute::AttrKind FuncAttr;
   int OverloadParamIndex;        // parameter index which control the overload.
                                  // When < 0, should be only 1 overload type.
@@ -249,16 +255,38 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
 }
 
+static uint16_t getValidOverloadMask(const OpCodeProperty *Prop,
+                                     uint32_t SMVer) {
+  uint16_t ValidTyMask = 0;
+  // std::vector Prop->OverloadProp is in ascending order of SM Version
+  // Overloads of highest SM version that is not greater than SMVer
+  // are the ones that are valid for SMVer.
+  for (auto OL : Prop->OverloadProp) {
+    if (OL.ShaderModelVer <= SMVer) {
+      ValidTyMask = OL.ValidTys;
+    } else {
+      break;
+    }
+  }
+  return ValidTyMask;
+}
+
 namespace llvm {
 namespace dxil {
 
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                                          Type *OverloadTy,
+CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
+                                          Type *ReturnTy, Type *OverloadTy,
                                           SmallVector<Value *> Args) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+  uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
 
+  if (ValidTyMask == 0) {
+    report_fatal_error(StringRef(std::to_string(SMVer).append(
+                           ": Unhandled Shader Model Version")),
+                       /*gen_crash_diag*/ false);
+  }
   OverloadKind Kind = getOverloadKind(OverloadTy);
-  if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
+  if ((ValidTyMask & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }
 
@@ -276,14 +304,22 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
   return B.CreateCall(DXILFn, Args);
 }
 
-Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
+Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
+                                   FunctionType *FT) {
 
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
   // If DXIL Op has no overload parameter, just return the
   // precise return type specified.
   if (Prop->OverloadParamIndex < 0) {
     auto &Ctx = FT->getContext();
-    switch (Prop->OverloadTys) {
+    uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
+    if (ValidTyMask == 0) {
+      report_fatal_error(StringRef(std::to_string(SMVer).append(
+                             ": Unhandled Shader Model Version")),
+                         /*gen_crash_diag*/ false);
+    }
+
+    switch (ValidTyMask) {
     case OverloadKind::VOID:
       return Type::getVoidTy(Ctx);
     case OverloadKind::HALF:
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 5babeae470178b..1e15286c810a8a 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -14,6 +14,7 @@
 
 #include "DXILConstants.h"
 #include "llvm/ADT/SmallVector.h"
+#include <cstdint>
 
 namespace llvm {
 class Module;
@@ -30,13 +31,17 @@ class DXILOpBuilder {
 public:
   DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
   /// Create an instruction that calls DXIL Op with return type, specified
-  /// opcode, and call arguments. \param OpCode Opcode of the DXIL Op call
-  /// constructed \param ReturnTy Return type of the DXIL Op call constructed
+  /// opcode, and call arguments.
+  ///
+  /// \param OpCode Opcode of the DXIL Op call constructed
+  /// \param SMVer Shader Model Version of DXIL Op call to construct
+  /// \param ReturnTy Return type of the DXIL Op call constructed
   /// \param OverloadTy Overload type of the DXIL Op call constructed
   /// \return DXIL Op call constructed
-  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                             Type *OverloadTy, SmallVector<Value *> Args);
-  Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
+  CallInst *createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
+                             Type *ReturnTy, Type *OverloadTy,
+                             SmallVector<Value *> Args);
+  Type *getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
 private:
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f09e322f88e1fd..c3217d51ece1b9 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/MC/TargetRegistry.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -72,10 +73,26 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   return NewOperands;
 }
 
+// Return the shader model version encoded in the module's target triple as
+// Major * 10 + Minor (e.g. shadermodel6.3 -> 63). NOTE(review): presumably
+// matches the MajorAndMinor encoding used by DXIL.td; confirm there.
+static uint32_t getShaderModelVer(Module &M) {
+  std::string TTStr = M.getTargetTriple();
+  std::string Error;
+  // Any unresolvable triple is fatal, not only an empty one; otherwise an
+  // unknown triple falls through and is parsed for a version anyway.
+  if (!TargetRegistry::lookupTarget(TTStr, Error))
+    report_fatal_error(StringRef(Error), /*gen_crash_diag*/ false);
+  const VersionTuple OSVer = Triple(TTStr).getOSVersion();
+  // NOTE(review): the x10 encoding assumes Minor < 10.
+  return OSVer.getMajor() * 10 + OSVer.getMinor().value_or(0);
+}
+
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   DXILOpBuilder DXILB(M, B);
-  Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
+  uint32_t SMVer = getShaderModelVer(M);
+  Type *OverloadTy = DXILB.getOverloadTy(DXILOp, SMVer, F.getFunctionType());
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
     if (!CI)
@@ -91,8 +108,8 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     } else
       Args.append(CI->arg_begin(), CI->arg_end());
 
-    CallInst *DXILCI =
-        DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, SMVer, F.getReturnType(),
+                                              OverloadTy, Args);
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/abs.ll b/llvm/test/CodeGen/DirectX/abs.ll
index 822580e8c089af..73bc00042ac161 100644
--- a/llvm/test/CodeGen/DirectX/abs.ll
+++ b/llvm/test/CodeGen/DirectX/abs.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for abs are generated for int16_t/int/int64_t.
 
diff --git a/llvm/test/CodeGen/DirectX/ceil.ll b/llvm/test/CodeGen/DirectX/ceil.ll
index 15854714678014..c7583cb70ba19f 100644
--- a/llvm/test/CodeGen/DirectX/ceil.ll
+++ b/llvm/test/CodeGen/DirectX/ceil.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for ceil are generated for float and half.
 
@@ -18,3 +18,4 @@ entry:
 
 declare half @llvm.ceil.f16(half)
 declare float @llvm.ceil.f32(float)
+
diff --git a/llvm/test/CodeGen/DirectX/ceil_error.ll b/llvm/test/CodeGen/DirectX/ceil_error.ll
index 1b554d8715566e..da6f083550186c 100644
--- a/llvm/test/CodeGen/DirectX/ceil_error.ll
+++ b/llvm/test/CodeGen/DirectX/ceil_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation ceil does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/clamp.ll b/llvm/test/CodeGen/DirectX/clamp.ll
index f122313b8d7dcc..2f29e4479f9ca1 100644
--- a/llvm/test/CodeGen/DirectX/clamp.ll
+++ b/llvm/test/CodeGen/DirectX/clamp.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for clamp/uclamp are generated for half/float/double/i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/cos.ll b/llvm/test/CodeGen/DirectX/cos.ll
index 00f2e2c3f6e5ab..72f4bfca23f9d5 100644
--- a/llvm/test/CodeGen/DirectX/cos.ll
+++ b/llvm/test/CodeGen/DirectX/cos.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for cos are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/cos_error.ll b/llvm/test/CodeGen/DirectX/cos_error.ll
index a074f5b493dfd6..6bb85a7cec1e30 100644
--- a/llvm/test/CodeGen/DirectX/cos_error.ll
+++ b/llvm/test/CodeGen/DirectX/cos_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation cos does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
index a27bfaedacd573..54780d18e71fb4 100644
--- a/llvm/test/CodeGen/DirectX/dot2_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation dot2 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll
index eb69fb145038aa..242716b0b71bad 100644
--- a/llvm/test/CodeGen/DirectX/dot3_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot3_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation dot3 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll
index 5cd632684c0c01..731adda153def8 100644
--- a/llvm/test/CodeGen/DirectX/dot4_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot4_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation dot4 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/exp.ll b/llvm/test/CodeGen/DirectX/exp.ll
index fdafc1438cf0e8..f67e2744c4ee34 100644
--- a/llvm/test/CodeGen/DirectX/exp.ll
+++ b/llvm/test/CodeGen/DirectX/exp.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for exp are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/exp2_error.ll b/llvm/test/CodeGen/DirectX/exp2_error.ll
index 6b9126785fd4b8..4d13f936eb6be2 100644
--- a/llvm/test/CodeGen/DirectX/exp2_error.ll
+++ b/llvm/test/CodeGen/DirectX/exp2_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation exp2 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/fabs.ll b/llvm/test/CodeGen/DirectX/fabs.ll
index 3b3f8aa9a4a928..1b3e91dcfb30f6 100644
--- a/llvm/test/CodeGen/DirectX/fabs.ll
+++ b/llvm/test/CodeGen/DirectX/fabs.ll
@@ -1,8 +1,7 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for abs are generated for float, half, and double.
 
-
 ; CHECK-LABEL: fabs_half
 define noundef half @fabs_half(half noundef %a) {
 entry:
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
index 3e13b2ad2650c8..bad65cfb1e561f 100644
--- a/llvm/test/CodeGen/DirectX/fdot.ll
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
diff --git a/llvm/test/CodeGen/DirectX/floor.ll b/llvm/test/CodeGen/DirectX/floor.ll
index b033e2eaa491e7..f667cab4aa249b 100644
--- a/llvm/test/CodeGen/DirectX/floor.ll
+++ b/llvm/test/CodeGen/DirectX/floor.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for floor are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/floor_error.ll b/llvm/test/CodeGen/DirectX/floor_error.ll
index 3b51a4b543b7f6..e3190e5afb63fa 100644
--- a/llvm/test/CodeGen/DirectX/floor_error.ll
+++ b/llvm/test/CodeGen/DirectX/floor_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation floor does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/fmax.ll b/llvm/test/CodeGen/DirectX/fmax.ll
index aff722c29309c0..05852ee33486d1 100644
--- a/llvm/test/CodeGen/DirectX/fmax.ll
+++ b/llvm/test/CodeGen/DirectX/fmax.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for fmax are generated for half/float/double.
 
diff --git a/llvm/test/CodeGen/DirectX/fmin.ll b/llvm/test/CodeGen/DirectX/fmin.ll
index 2f7c209f0278ae..1c6c7ca3f2e38a 100644
--- a/llvm/test/CodeGen/DirectX/fmin.ll
+++ b/llvm/test/CodeGen/DirectX/fmin.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for fmin are generated for half/float/double.
 
diff --git a/llvm/test/CodeGen/DirectX/frac_error.ll b/llvm/test/CodeGen/DirectX/frac_error.ll
index ebce76105ad4d7..1bc3558ab0c9a5 100644
--- a/llvm/test/CodeGen/DirectX/frac_error.ll
+++ b/llvm/test/CodeGen/DirectX/frac_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation frac does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
index 9f89a8d6d340d5..8fad7b00700f5a 100644
--- a/llvm/test/CodeGen/DirectX/idot.ll
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library  %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
diff --git a/llvm/test/CodeGen/DirectX/isinf.ll b/llvm/test/CodeGen/DirectX/isinf.ll
index e2975da90bfc1b..bbacaa0f99bac6 100644
--- a/llvm/test/CodeGen/DirectX/isinf.ll
+++ b/llvm/test/CodeGen/DirectX/isinf.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
 
 ; Make sure dxil operation function calls for isinf are generated for float and half.
 ; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
diff --git a/llvm/test/CodeGen/DirectX/isinf_error.ll b/llvm/test/CodeGen/DirectX/isinf_error.ll
index 95b2d0cabcc43b..39b83554d74d0e 100644
--- a/llvm/test/CodeGen/DirectX/isinf_error.ll
+++ b/llvm/test/CodeGen/DirectX/isinf_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation isinf does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/log.ll b/llvm/test/CodeGen/DirectX/log.ll
index 172c3bfed3b770..36344cf7a5f6dd 100644
--- a/llvm/test/CodeGen/DirectX/log.ll
+++ b/llvm/test/CodeGen/DirectX/log.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library  %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for log are generated.
 
diff --git a/llvm/test/CodeGen/DirectX/log10.ll b/llvm/test/CodeGen/DirectX/log10.ll
index d4f827a0d1af83..8e40ccd8d13313 100644
--- a/llvm/test/CodeGen/DirectX/log10.ll
+++ b/llvm/test/CodeGen/DirectX/log10.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library  %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for log10 are generated.
 
diff --git a/llvm/test/CodeGen/DirectX/log2.ll b/llvm/test/CodeGen/DirectX/log2.ll
index 2164d4db9396d1..d6a7ba0b7dda75 100644
--- a/llvm/test/CodeGen/DirectX/log2.ll
+++ b/llvm/test/CodeGen/DirectX/log2.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for log2 are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/log2_error.ll b/llvm/test/CodeGen/DirectX/log2_error.ll
index a26f6e8c3117f5..b8876854d389fb 100644
--- a/llvm/test/CodeGen/DirectX/log2_error.ll
+++ b/llvm/test/CodeGen/DirectX/log2_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation log2 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/pow.ll b/llvm/test/CodeGen/DirectX/pow.ll
index 25ce0fe731d0ba..4ed886532f5909 100644
--- a/llvm/test/CodeGen/DirectX/pow.ll
+++ b/llvm/test/CodeGen/DirectX/pow.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for pow are generated.
 
diff --git a/llvm/test/CodeGen/DirectX/reversebits.ll b/llvm/test/CodeGen/DirectX/reversebits.ll
index b6a7a1bc6152e3..1ade57b40100ff 100644
--- a/llvm/test/CodeGen/DirectX/reversebits.ll
+++ b/llvm/test/CodeGen/DirectX/reversebits.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for reversebits are generated for all integer types.
 
diff --git a/llvm/test/CodeGen/DirectX/round.ll b/llvm/test/CodeGen/DirectX/round.ll
index e0a3772ebca8fa..db953fb29c2046 100644
--- a/llvm/test/CodeGen/DirectX/round.ll
+++ b/llvm/test/CodeGen/DirectX/round.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for round are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/round_error.ll b/llvm/test/CodeGen/DirectX/round_error.ll
index 2d27fbb5ee20de..9d2a4e778a9249 100644
--- a/llvm/test/CodeGen/DirectX/round_error.ll
+++ b/llvm/test/CodeGen/DirectX/round_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; This test is expected to fail with the following error
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/rsqrt.ll b/llvm/test/CodeGen/DirectX/rsqrt.ll
index 52af0e62220b3e..054c84483ef826 100644
--- a/llvm/test/CodeGen/DirectX/rsqrt.ll
+++ b/llvm/test/CodeGen/DirectX/rsqrt.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for rsqrt are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/rsqrt_error.ll b/llvm/test/CodeGen/DirectX/rsqrt_error.ll
index 9cd5002c20f7ec..5e29e37113d19f 100644
--- a/llvm/test/CodeGen/DirectX/rsqrt_error.ll
+++ b/llvm/test/CodeGen/DirectX/rsqrt_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation rsqrt does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index 1f285c433581cf..789f6a73fe6aaa 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -1,8 +1,6 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s -check-prefix=SM6_3
 
 ; Make sure dxil operation function calls for sin are generated for float and half.
-; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
-; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
 
 ; Function Attrs: noinline nounwind optnone
 define noundef float @sin_float(float noundef %a) #0 {
@@ -10,6 +8,7 @@ entry:
   %a.addr = alloca float, align 4
   store float %a, ptr %a.addr, align 4
   %0 = load float, ptr %a.addr, align 4
+  ; SM6_3: call float @dx.op.unary.f32(i32 13, float %{{.*}})
   %1 = call float @llvm.sin.f32(float %0)
   ret float %1
 }
@@ -20,6 +19,7 @@ entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
   %0 = load half, ptr %a.addr, align 2
+  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
index ece0e530315b2f..3e954f32d9bb1b 100644
--- a/llvm/test/CodeGen/DirectX/sin_error.ll
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation sin does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/sin_sm_error.ll b/llvm/test/CodeGen/DirectX/sin_sm_error.ll
new file mode 100644
index 00000000000000..84bf54901f313a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_sm_error.ll
@@ -0,0 +1,15 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s
+
+; DXIL operation sin does not support half overload type in SM6.0
+; CHECK: LLVM ERROR: Invalid Overload
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @sin_half(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
+  %1 = call half @llvm.sin.f16(half %0)
+  ret half %1
+}
diff --git a/llvm/test/CodeGen/DirectX/smax.ll b/llvm/test/CodeGen/DirectX/smax.ll
index 8b2406782c0938..bcda51cb0bfba6 100644
--- a/llvm/test/CodeGen/DirectX/smax.ll
+++ b/llvm/test/CodeGen/DirectX/smax.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for smax are generated for i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/smin.ll b/llvm/test/CodeGen/DirectX/smin.ll
index b2b40a1b624335..8d4884704df213 100644
--- a/llvm/test/CodeGen/DirectX/smin.ll
+++ b/llvm/test/CodeGen/DirectX/smin.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for smin are generated for i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/sqrt.ll b/llvm/test/CodeGen/DirectX/sqrt.ll
index 76a572efd20557..792fbc8d0614d3 100644
--- a/llvm/test/CodeGen/DirectX/sqrt.ll
+++ b/llvm/test/CodeGen/DirectX/sqrt.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for sqrt are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/sqrt_error.ll b/llvm/test/CodeGen/DirectX/sqrt_error.ll
index fffa2e19b80fa9..1477abc62c13a2 100644
--- a/llvm/test/CodeGen/DirectX/sqrt_error.ll
+++ b/llvm/test/CodeGen/DirectX/sqrt_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation sqrt does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/trunc.ll b/llvm/test/CodeGen/DirectX/trunc.ll
index 2072f28cef50a0..f00b737da4dbb3 100644
--- a/llvm/test/CodeGen/DirectX/trunc.ll
+++ b/llvm/test/CodeGen/DirectX/trunc.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for trunc are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/trunc_error.ll b/llvm/test/CodeGen/DirectX/trunc_error.ll
index 751b0b94c280df..ccc7b1df879ee3 100644
--- a/llvm/test/CodeGen/DirectX/trunc_error.ll
+++ b/llvm/test/CodeGen/DirectX/trunc_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation trunc does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/umax.ll b/llvm/test/CodeGen/DirectX/umax.ll
index be0f557fc8da69..a4bd66ef0bd6c3 100644
--- a/llvm/test/CodeGen/DirectX/umax.ll
+++ b/llvm/test/CodeGen/DirectX/umax.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for umax are generated for i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/umin.ll b/llvm/test/CodeGen/DirectX/umin.ll
index 5051c711744892..a551f8ff3bfa9d 100644
--- a/llvm/test/CodeGen/DirectX/umin.ll
+++ b/llvm/test/CodeGen/DirectX/umin.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for umin are generated for i16/i32/i64.
 
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 77f0d2cb8c937f..0af90ebee5d2d5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -23,7 +23,11 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
+
+#include <algorithm>
+#include <cstdint>
 #include <string>
+#include <vector>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -177,7 +181,16 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   }
 
   // Get valid overload types of the Operation
-  auto OverloadTypeRecs = R->getValueAsListOfDefs("OpOverloadTypes");
+  std::vector<Record *> OverloadTypeRecs =
+      R->getValueAsListOfDefs("OpOverloadTypes");
+  // Sort records in ascending order of Shader Model version
+  std::sort(
+      OverloadTypeRecs.begin(), OverloadTypeRecs.end(),
+      [](Record *a, Record *b) {
+        return (
+            a->getValueAsDef("ShaderModel")->getValueAsInt("MajorAndMinor") <
+            b->getValueAsDef("ShaderModel")->getValueAsInt("MajorAndMinor"));
+      });
   unsigned OverloadTypeRecsSize = OverloadTypeRecs.size();
   // Populate OpOverloads with
   for (unsigned I = 0; I < OverloadTypeRecsSize; I++) {
@@ -268,20 +281,33 @@ static std::string getOverloadKindStr(const Record *R) {
 }
 /// Return a string representation of OverloadKind enum that maps to
 /// input LLVMType record
-/// \param R TableGen def record of class LLVMType
+/// \param Recs Records each holding a ShaderModel def and an OpOverloads list
 /// \return std::string string representation of OverloadKind
 
+// Output: {{MajorAndMinor, Kind | ...}, ...} with one inner entry per
+// record in Recs, in the order given.
+
 static std::string
-getOverloadKindStrs(const SmallVector<Record *> OverloadTys) {
-  if (OverloadTys.empty()) {
-    return {};
-  }
+getOverloadKindStrs(const SmallVector<Record *> &Recs) {
   std::string OverloadString = "";
-  auto Iter = OverloadTys.begin();
-  OverloadString.append(getOverloadKindStr(*Iter++));
-  for (; Iter != OverloadTys.end(); ++Iter) {
-    OverloadString.append(" | ").append(getOverloadKindStr(*Iter));
+  std::string Prefix = "";
+  OverloadString.append("{");
+  for (const Record *OvRec : Recs) {
+    // Emit one "{MajorAndMinor, Kind | Kind ...}" entry per record.
+    auto OverloadTys = OvRec->getValueAsListOfDefs("OpOverloads");
+    // Reject an empty list: the join below reads OverloadTys[0].
+    if (OverloadTys.empty())
+      report_fatal_error("DXIL OpOverloads list must not be empty", false);
+    OverloadString.append(Prefix).append("{").append(
+        std::to_string(OvRec->getValueAsDef("ShaderModel")
+                           ->getValueAsInt("MajorAndMinor")));
+    OverloadString.append(", ").append(getOverloadKindStr(OverloadTys[0]));
+    for (size_t I = 1, E = OverloadTys.size(); I != E; ++I)
+      OverloadString.append(" | ").append(getOverloadKindStr(OverloadTys[I]));
+    OverloadString.append("}");
+    Prefix = ", ";
    }
+  OverloadString.append("}");
   return OverloadString;
 }
 
@@ -403,6 +429,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
         "{\n";
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
+  std::string Prefix = "";
   for (auto &Op : Ops) {
     // Consider Op.OverloadParamIndex as the overload parameter index, by
     // default
@@ -414,13 +441,14 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     if (OLParamIdx < 0) {
       OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
     }
-    OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
+    OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
        << getOverloadKindStrs(Op.OpOverloads) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
-       << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
+       << Parameters.get(ParameterMap[Op.OpClass]) << " }\n";
+       Prefix = ",";
   }
   OS << "  };\n";
 

>From 658ccae81822447938fe84561402287b7e8fd78b Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 5 Apr 2024 11:40:41 -0400
Subject: [PATCH 3/4] Reorganize tests to lower llvm.sin.*. This pattern of
 tests will facilitate use of same test sources to test lowering of various
 combinations of options.

---
 .../{sin_error.ll => Inputs/sin/double.ll}    |  4 --
 llvm/test/CodeGen/DirectX/Inputs/sin/float.ll |  9 ++++
 .../{sin_sm_error.ll => Inputs/sin/half.ll}   |  6 ---
 llvm/test/CodeGen/DirectX/sin.ll              | 42 +++++++++----------
 4 files changed, 29 insertions(+), 32 deletions(-)
 rename llvm/test/CodeGen/DirectX/{sin_error.ll => Inputs/sin/double.ll} (55%)
 create mode 100644 llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
 rename llvm/test/CodeGen/DirectX/{sin_sm_error.ll => Inputs/sin/half.ll} (50%)

diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
similarity index 55%
rename from llvm/test/CodeGen/DirectX/sin_error.ll
rename to llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
index 3e954f32d9bb1b..949649e9b5b11c 100644
--- a/llvm/test/CodeGen/DirectX/sin_error.ll
+++ b/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
@@ -1,7 +1,3 @@
-; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
-
-; DXIL operation sin does not support double overload type
-; CHECK: LLVM ERROR: Invalid Overload
 
 define noundef double @sin_double(double noundef %a) #0 {
 entry:
diff --git a/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
new file mode 100644
index 00000000000000..6558385e88d67b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
@@ -0,0 +1,9 @@
+; Function Attrs: noinline nounwind optnone
+define noundef float @sin_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %1 = call float @llvm.sin.f32(float %0)
+  ret float %1
+}
diff --git a/llvm/test/CodeGen/DirectX/sin_sm_error.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
similarity index 50%
rename from llvm/test/CodeGen/DirectX/sin_sm_error.ll
rename to llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
index 84bf54901f313a..39fbf3d51701d8 100644
--- a/llvm/test/CodeGen/DirectX/sin_sm_error.ll
+++ b/llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
@@ -1,15 +1,9 @@
-; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s
-
-; DXIL operation sin does not support half overload type in SM6.0
-; CHECK: LLVM ERROR: Invalid Overload
-
 ; Function Attrs: noinline nounwind optnone
 define noundef half @sin_half(half noundef %a) #0 {
 entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
   %0 = load half, ptr %a.addr, align 2
-  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index 789f6a73fe6aaa..ac8ab1ec48339d 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -1,25 +1,23 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s -check-prefix=SM6_3
+// Shader Model 6.0
+// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/half.ll 2>&1 | FileCheck %s -check-prefix=SM6_0_HALF
+// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/float.ll | FileCheck %s -check-prefix=SM6_0_FLOAT
+// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/double.ll 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
 
-; Make sure dxil operation function calls for sin are generated for float and half.
+// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/half.ll | FileCheck %s -check-prefix=SM6_3_HALF
+// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/float.ll | FileCheck %s -check-prefix=SM6_3_FLOAT
+// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/double.ll 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
 
-; Function Attrs: noinline nounwind optnone
-define noundef float @sin_float(float noundef %a) #0 {
-entry:
-  %a.addr = alloca float, align 4
-  store float %a, ptr %a.addr, align 4
-  %0 = load float, ptr %a.addr, align 4
-  ; SM6_3: call float @dx.op.unary.f32(i32 13, float %{{.*}})
-  %1 = call float @llvm.sin.f32(float %0)
-  ret float %1
-}
+// Float is valid for SM6.0
+// SM6_0_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
+
+// Half is not valid for SM6.0
+// SM6_0_HALF: LLVM ERROR: Invalid Overload
+
+// Half and float are valid for SM6.2 and later
+// SM6_3_HALF: call half @dx.op.unary.f16(i32 13, half %{{.*}})
+// SM6_3_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
+
+// Double is not valid in any Shader Model version
+// SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
+// SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
 
-; Function Attrs: noinline nounwind optnone
-define noundef half @sin_half(half noundef %a) #0 {
-entry:
-  %a.addr = alloca half, align 2
-  store half %a, ptr %a.addr, align 2
-  %0 = load half, ptr %a.addr, align 2
-  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
-  %1 = call half @llvm.sin.f16(half %0)
-  ret half %1
-}

>From 0e43baf3efa01fdc142bb110befe71f9b9149c75 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 5 Apr 2024 11:50:47 -0400
Subject: [PATCH 4/4] Delete DXILClass defs that are commented out. Need to add
 corresponding class defs when adding DXIL Ops of a new class.

---
 llvm/lib/Target/DirectX/DXIL.td     | 203 +++-------------------------
 llvm/utils/TableGen/DXILEmitter.cpp |   9 +-
 2 files changed, 24 insertions(+), 188 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 0ea0f016fd6c74..ce8a14cda94a9d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -53,194 +53,31 @@ class DXILOpClass<list<LLVMType> OpSig> {
 // utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
 // name has the format "dx.op.<class-name>.<return-type>", in most cases.
 
+// NOTE: The following list is not complete. Classes need to be defined as new DXIL Ops
+// are added.
 defset list<DXILOpClass> OpClasses = {
-def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
-//  def allocateNodeOutputRecords : DXILOpClass;
-def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-//  def annotateHandle : DXILOpClass;
-//  def annotateNodeHandle : DXILOpClass;
-//  def annotateNodeRecordHandle : DXILOpClass;
-//  def atomicBinOp : DXILOpClass;
-//  def atomicCompareExchange : DXILOpClass;
-def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
-def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
-//  def barrierByMemoryHandle : DXILOpClass;
-def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
-//  def barrierByNodeRecordHandle : DXILOpClass;
-def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
-//  def binaryWithTwoOuts : DXILOpClass;
-//  def bitcastF16toI16 : DXILOpClass;
-//  def bitcastF32toI32 : DXILOpClass;
-//  def bitcastF64toI64 : DXILOpClass;
-//  def bitcastI16toF16 : DXILOpClass;
-//  def bitcastI32toF32 : DXILOpClass;
-//  def bitcastI64toF64 : DXILOpClass;
-//  def bufferLoad : DXILOpClass;
-//  def bufferStore : DXILOpClass;
-//  def bufferUpdateCounter : DXILOpClass;
-//  def calculateLOD : DXILOpClass;
-//  def callShader : DXILOpClass;
-//  def cbufferLoad : DXILOpClass;
-//  def cbufferLoadLegacy : DXILOpClass;
-//  def checkAccessFullyMapped : DXILOpClass;
-//  def coverage : DXILOpClass;
-//  def createHandle : DXILOpClass;
-//  def createHandleForLib : DXILOpClass;
-//  def createHandleFromBinding : DXILOpClass;
-//  def createHandleFromHeap : DXILOpClass;
-//  def createNodeInputRecordHandle : DXILOpClass;
-//  def createNodeOutputHandle : DXILOpClass;
-//  def cutStream : DXILOpClass;
-//  def cycleCounterLegacy : DXILOpClass;
-//  def discard : DXILOpClass;
-//  def dispatchMesh : DXILOpClass;
-//  def dispatchRaysDimensions : DXILOpClass;
-//  def dispatchRaysIndex : DXILOpClass;
-//  def domainLocation : DXILOpClass;
-def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
-//  def dot2AddHalf : DXILOpClass;
-def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
-def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
-//  def dot4AddPacked : DXILOpClass;
-//  def emitIndices : DXILOpClass;
-//  def emitStream : DXILOpClass;
-//  def emitThenCutStream : DXILOpClass;
-//  def evalCentroid : DXILOpClass;
-//  def evalSampleIndex : DXILOpClass;
-//  def evalSnapped : DXILOpClass;
-//  def finishedCrossGroupSharing : DXILOpClass;
-def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
-//  def geometryIndex : DXILOpClass;
-//  def getDimensions : DXILOpClass;
-//  def getInputRecordCount : DXILOpClass;
-//  def getMeshPayload : DXILOpClass;
-//  def getNodeRecordPtr : DXILOpClass;
-//  def getRemainingRecursionLevels : DXILOpClass;
-def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-//  def gsInstanceID : DXILOpClass;
-//  def hitKind : DXILOpClass;
-//  def ignoreHit : DXILOpClass;
-//  def incrementOutputCount : DXILOpClass;
-//  def indexNodeHandle : DXILOpClass;
-//  def innerCoverage : DXILOpClass;
-//  def instanceID : DXILOpClass;
-//  def instanceIndex : DXILOpClass;
-//  def isHelperLane : DXILOpClass;
-def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
-//  def legacyDoubleToFloat : DXILOpClass;
-//  def legacyDoubleToSInt32 : DXILOpClass;
-//  def legacyDoubleToUInt32 : DXILOpClass;
-//  def legacyF16ToF32 : DXILOpClass;
-//  def legacyF32ToF16 : DXILOpClass;
-//  def loadInput : DXILOpClass;
-//  def loadOutputControlPoint : DXILOpClass;
-//  def loadPatchConstant : DXILOpClass;
-//  def makeDouble : DXILOpClass;
-//  def minPrecXRegLoad : DXILOpClass;
-//  def minPrecXRegStore : DXILOpClass;
-//  def nodeOutputIsValid : DXILOpClass;
-//  def objectRayDirection : DXILOpClass;
-//  def objectRayOrigin : DXILOpClass;
-//  def objectToWorld : DXILOpClass;
-//  def outputComplete : DXILOpClass;
-//  def outputControlPointID : DXILOpClass;
-//  def pack4x8 : DXILOpClass;
-//  def primitiveID : DXILOpClass;
-//  def primitiveIndex : DXILOpClass;
-//  def quadOp : DXILOpClass;
-//  def quadReadLaneAt : DXILOpClass;
-//  def quadVote : DXILOpClass;
-//  def quaternary : DXILOpClass;
-//  def rawBufferLoad : DXILOpClass;
-//  def rawBufferStore : DXILOpClass;
-//  def rayFlags : DXILOpClass;
-//  def rayQuery_Abort : DXILOpClass;
-//  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
-//  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
-//  def rayQuery_Proceed : DXILOpClass;
-//  def rayQuery_StateMatrix : DXILOpClass;
-//  def rayQuery_StateScalar : DXILOpClass;
-//  def rayQuery_StateVector : DXILOpClass;
-//  def rayQuery_TraceRayInline : DXILOpClass;
-//  def rayTCurrent : DXILOpClass;
-//  def rayTMin : DXILOpClass;
-//  def renderTargetGetSampleCount : DXILOpClass;
-//  def renderTargetGetSamplePosition : DXILOpClass;
-//  def reportHit : DXILOpClass;
-//  def sample : DXILOpClass;
-//  def sampleBias : DXILOpClass;
-//  def sampleCmp : DXILOpClass;
-//  def sampleCmpBias : DXILOpClass;
-//  def sampleCmpGrad : DXILOpClass;
-//  def sampleCmpLevel : DXILOpClass;
-//  def sampleCmpLevelZero : DXILOpClass;
-//  def sampleGrad : DXILOpClass;
-//  def sampleIndex : DXILOpClass;
-//  def sampleLevel : DXILOpClass;
-//  def setMeshOutputCounts : DXILOpClass;
-//  def splitDouble : DXILOpClass;
-//  def startInstanceLocation : DXILOpClass;
-//  def startVertexLocation : DXILOpClass;
-//  def storeOutput : DXILOpClass;
-//  def storePatchConstant : DXILOpClass;
-//  def storePrimitiveOutput : DXILOpClass;
-//  def storeVertexOutput : DXILOpClass;
-//  def tempRegLoad : DXILOpClass;
-//  def tempRegStore : DXILOpClass;
-def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-//  def texture2DMSGetSamplePosition : DXILOpClass;
-//  def textureGather : DXILOpClass;
-//  def textureGatherCmp : DXILOpClass;
-//  def textureGatherRaw : DXILOpClass;
-//  def textureLoad : DXILOpClass;
-//  def textureStore : DXILOpClass;
-//  def textureStoreSample : DXILOpClass;
-def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-//  def traceRay : DXILOpClass;
-def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
-//  def unaryBits : DXILOpClass;
-//  def unpack4x8 : DXILOpClass;
-//  def viewID : DXILOpClass;
-//  def waveActiveAllEqual : DXILOpClass;
-//  def waveActiveBallot : DXILOpClass;
-//  def waveActiveBit : DXILOpClass;
-//  def waveActiveOp : DXILOpClass;
-//  def waveAllOp : DXILOpClass;
-//  def waveAllTrue : DXILOpClass;
-//  def waveAnyTrue : DXILOpClass;
-//  def waveGetLaneCount : DXILOpClass;
-//  def waveGetLaneIndex : DXILOpClass;
-//  def waveIsFirstLane : DXILOpClass;
-//  def waveMatch : DXILOpClass;
-//  def waveMatrix_Accumulate : DXILOpClass;
-//  def waveMatrix_Annotate : DXILOpClass;
-//  def waveMatrix_Depth : DXILOpClass;
-//  def waveMatrix_Fill : DXILOpClass;
-//  def waveMatrix_LoadGroupShared : DXILOpClass;
-//  def waveMatrix_LoadRawBuf : DXILOpClass;
-//  def waveMatrix_Multiply : DXILOpClass;
-//  def waveMatrix_ScalarOp : DXILOpClass;
-//  def waveMatrix_StoreGroupShared : DXILOpClass;
-//  def waveMatrix_StoreRawBuf : DXILOpClass;
-//  def waveMultiPrefixBitCount : DXILOpClass;
-//  def waveMultiPrefixOp : DXILOpClass;
-//  def wavePrefixOp : DXILOpClass;
-//  def waveReadLaneAt : DXILOpClass;
-//  def waveReadLaneFirst : DXILOpClass;
-//  def worldRayDirection : DXILOpClass;
-//  def worldRayOrigin : DXILOpClass;
-//  def worldToObject : DXILOpClass;
-//  def writeSamplerFeedback : DXILOpClass;
-//  def writeSamplerFeedbackBias : DXILOpClass;
-//  def writeSamplerFeedbackGrad : DXILOpClass;
-//  def writeSamplerFeedbackLevel: DXILOpClass;
+  def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
+  def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
+  def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
+  def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
+  def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
+  def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
+  def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
+  def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
+  def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
+  def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
+  def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
+  def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+  def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
 
   // This is a sentinel definition. Hence placed at the end of the list
   // and not as part of the above alphabetically sorted valid definitions.
   // Additionally it is capitalized unlike all the others.
-def UnknownOpClass: DXILOpClass<[]>;
+  def UnknownOpClass: DXILOpClass<[]>;
 }
 
 // Abstraction DXIL Operation to LLVM intrinsic
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 0af90ebee5d2d5..7c9f3da3e157d8 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -287,8 +287,7 @@ static std::string getOverloadKindStr(const Record *R) {
 // Constant value that is used to encode shader model version
 // denoting SM5.0
 
-static std::string
-getOverloadKindStrs(const SmallVector<Record *> Recs) {
+static std::string getOverloadKindStrs(const SmallVector<Record *> Recs) {
   std::string OverloadString = "";
   std::string Prefix = "";
   OverloadString.append("{");
@@ -441,14 +440,14 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     if (OLParamIdx < 0) {
       OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
     }
-    OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
-       << ", OpCodeClass::" << Op.OpClass << ", "
+    OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", "
+       << OpStrings.get(Op.OpName) << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
        << getOverloadKindStrs(Op.OpOverloads) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " }\n";
-       Prefix = ",";
+    Prefix = ",";
   }
   OS << "  };\n";
 



More information about the llvm-commits mailing list