[llvm] [DXIL] Model DXIL Class and Shader Model association of DXIL Ops in DXIL.td (PR #87803)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 11 09:17:11 PDT 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/87803

>From b699f5f90154cc0149a770ec61861de922d3decc Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 29 Mar 2024 12:11:35 -0400
Subject: [PATCH 1/6] [DXIL] Implement DXIL Ops specification using OpClass
 properties Each DXIL OpClass represents DXIL Ops with the same function
 prototype (signature). Represent this property in a TableGen class and add an
 overload types field with the DXILOpMapping to denote valid overload types of
 a DXIL Op record being defined.

---
 llvm/lib/Target/DirectX/DXIL.td     | 632 ++++++++++++++--------------
 llvm/utils/TableGen/DXILEmitter.cpp | 108 ++---
 2 files changed, 382 insertions(+), 358 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index cd388ed3e3191b..082e7596aded35 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,331 +13,349 @@
 
 include "llvm/IR/Intrinsics.td"
 
-class DXILOpClass;
+// Abstraction of DXIL Operation class.
+// It encapsulates an associated function signature viz.,
+// returnty(param1, param2, ...) represented as a list of LLVMTypes.
+// DXIL Ops that belong to a DXILOpClass record the signature of that
+// DXILOpClass
 
-// Following is a set of DXIL Operation classes whose names appear to be
-// arbitrary, yet need to be a substring of the function name used during
-// lowering to DXIL Operation calls. These class name strings are specified
-// as the third argument of add_dixil_op in utils/hct/hctdb.py and case converted
-// in utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
-// name has the format "dx.op.<class-name>.<return-type>".
+class DXILOpClass<list<LLVMType> OpSig> {
+  list<LLVMType> OpSignature = OpSig;
+}
+
+// Concrete definitions of DXIL Op Classes
+// Note that these class name strings are specified as the third argument
+// of add_dxil_op in utils/hct/hctdb.py and case converted in
+// utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
+// name has the format "dx.op.<class-name>.<return-type>", in most cases.
 
 defset list<DXILOpClass> OpClasses = {
-  def acceptHitAndEndSearch : DXILOpClass;
-  def allocateNodeOutputRecords : DXILOpClass;
-  def allocateRayQuery : DXILOpClass;
-  def annotateHandle : DXILOpClass;
-  def annotateNodeHandle : DXILOpClass;
-  def annotateNodeRecordHandle : DXILOpClass;
-  def atomicBinOp : DXILOpClass;
-  def atomicCompareExchange : DXILOpClass;
-  def attributeAtVertex : DXILOpClass;
-  def barrier : DXILOpClass;
-  def barrierByMemoryHandle : DXILOpClass;
-  def barrierByMemoryType : DXILOpClass;
-  def barrierByNodeRecordHandle : DXILOpClass;
-  def binary : DXILOpClass;
-  def binaryWithCarryOrBorrow : DXILOpClass;
-  def binaryWithTwoOuts : DXILOpClass;
-  def bitcastF16toI16 : DXILOpClass;
-  def bitcastF32toI32 : DXILOpClass;
-  def bitcastF64toI64 : DXILOpClass;
-  def bitcastI16toF16 : DXILOpClass;
-  def bitcastI32toF32 : DXILOpClass;
-  def bitcastI64toF64 : DXILOpClass;
-  def bufferLoad : DXILOpClass;
-  def bufferStore : DXILOpClass;
-  def bufferUpdateCounter : DXILOpClass;
-  def calculateLOD : DXILOpClass;
-  def callShader : DXILOpClass;
-  def cbufferLoad : DXILOpClass;
-  def cbufferLoadLegacy : DXILOpClass;
-  def checkAccessFullyMapped : DXILOpClass;
-  def coverage : DXILOpClass;
-  def createHandle : DXILOpClass;
-  def createHandleForLib : DXILOpClass;
-  def createHandleFromBinding : DXILOpClass;
-  def createHandleFromHeap : DXILOpClass;
-  def createNodeInputRecordHandle : DXILOpClass;
-  def createNodeOutputHandle : DXILOpClass;
-  def cutStream : DXILOpClass;
-  def cycleCounterLegacy : DXILOpClass;
-  def discard : DXILOpClass;
-  def dispatchMesh : DXILOpClass;
-  def dispatchRaysDimensions : DXILOpClass;
-  def dispatchRaysIndex : DXILOpClass;
-  def domainLocation : DXILOpClass;
-  def dot2 : DXILOpClass;
-  def dot2AddHalf : DXILOpClass;
-  def dot3 : DXILOpClass;
-  def dot4 : DXILOpClass;
-  def dot4AddPacked : DXILOpClass;
-  def emitIndices : DXILOpClass;
-  def emitStream : DXILOpClass;
-  def emitThenCutStream : DXILOpClass;
-  def evalCentroid : DXILOpClass;
-  def evalSampleIndex : DXILOpClass;
-  def evalSnapped : DXILOpClass;
-  def finishedCrossGroupSharing : DXILOpClass;
-  def flattenedThreadIdInGroup : DXILOpClass;
-  def geometryIndex : DXILOpClass;
-  def getDimensions : DXILOpClass;
-  def getInputRecordCount : DXILOpClass;
-  def getMeshPayload : DXILOpClass;
-  def getNodeRecordPtr : DXILOpClass;
-  def getRemainingRecursionLevels : DXILOpClass;
-  def groupId : DXILOpClass;
-  def gsInstanceID : DXILOpClass;
-  def hitKind : DXILOpClass;
-  def ignoreHit : DXILOpClass;
-  def incrementOutputCount : DXILOpClass;
-  def indexNodeHandle : DXILOpClass;
-  def innerCoverage : DXILOpClass;
-  def instanceID : DXILOpClass;
-  def instanceIndex : DXILOpClass;
-  def isHelperLane : DXILOpClass;
-  def isSpecialFloat : DXILOpClass;
-  def legacyDoubleToFloat : DXILOpClass;
-  def legacyDoubleToSInt32 : DXILOpClass;
-  def legacyDoubleToUInt32 : DXILOpClass;
-  def legacyF16ToF32 : DXILOpClass;
-  def legacyF32ToF16 : DXILOpClass;
-  def loadInput : DXILOpClass;
-  def loadOutputControlPoint : DXILOpClass;
-  def loadPatchConstant : DXILOpClass;
-  def makeDouble : DXILOpClass;
-  def minPrecXRegLoad : DXILOpClass;
-  def minPrecXRegStore : DXILOpClass;
-  def nodeOutputIsValid : DXILOpClass;
-  def objectRayDirection : DXILOpClass;
-  def objectRayOrigin : DXILOpClass;
-  def objectToWorld : DXILOpClass;
-  def outputComplete : DXILOpClass;
-  def outputControlPointID : DXILOpClass;
-  def pack4x8 : DXILOpClass;
-  def primitiveID : DXILOpClass;
-  def primitiveIndex : DXILOpClass;
-  def quadOp : DXILOpClass;
-  def quadReadLaneAt : DXILOpClass;
-  def quadVote : DXILOpClass;
-  def quaternary : DXILOpClass;
-  def rawBufferLoad : DXILOpClass;
-  def rawBufferStore : DXILOpClass;
-  def rayFlags : DXILOpClass;
-  def rayQuery_Abort : DXILOpClass;
-  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
-  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
-  def rayQuery_Proceed : DXILOpClass;
-  def rayQuery_StateMatrix : DXILOpClass;
-  def rayQuery_StateScalar : DXILOpClass;
-  def rayQuery_StateVector : DXILOpClass;
-  def rayQuery_TraceRayInline : DXILOpClass;
-  def rayTCurrent : DXILOpClass;
-  def rayTMin : DXILOpClass;
-  def renderTargetGetSampleCount : DXILOpClass;
-  def renderTargetGetSamplePosition : DXILOpClass;
-  def reportHit : DXILOpClass;
-  def sample : DXILOpClass;
-  def sampleBias : DXILOpClass;
-  def sampleCmp : DXILOpClass;
-  def sampleCmpBias : DXILOpClass;
-  def sampleCmpGrad : DXILOpClass;
-  def sampleCmpLevel : DXILOpClass;
-  def sampleCmpLevelZero : DXILOpClass;
-  def sampleGrad : DXILOpClass;
-  def sampleIndex : DXILOpClass;
-  def sampleLevel : DXILOpClass;
-  def setMeshOutputCounts : DXILOpClass;
-  def splitDouble : DXILOpClass;
-  def startInstanceLocation : DXILOpClass;
-  def startVertexLocation : DXILOpClass;
-  def storeOutput : DXILOpClass;
-  def storePatchConstant : DXILOpClass;
-  def storePrimitiveOutput : DXILOpClass;
-  def storeVertexOutput : DXILOpClass;
-  def tempRegLoad : DXILOpClass;
-  def tempRegStore : DXILOpClass;
-  def tertiary : DXILOpClass;
-  def texture2DMSGetSamplePosition : DXILOpClass;
-  def textureGather : DXILOpClass;
-  def textureGatherCmp : DXILOpClass;
-  def textureGatherRaw : DXILOpClass;
-  def textureLoad : DXILOpClass;
-  def textureStore : DXILOpClass;
-  def textureStoreSample : DXILOpClass;
-  def threadId : DXILOpClass;
-  def threadIdInGroup : DXILOpClass;
-  def traceRay : DXILOpClass;
-  def unary : DXILOpClass;
-  def unaryBits : DXILOpClass;
-  def unpack4x8 : DXILOpClass;
-  def viewID : DXILOpClass;
-  def waveActiveAllEqual : DXILOpClass;
-  def waveActiveBallot : DXILOpClass;
-  def waveActiveBit : DXILOpClass;
-  def waveActiveOp : DXILOpClass;
-  def waveAllOp : DXILOpClass;
-  def waveAllTrue : DXILOpClass;
-  def waveAnyTrue : DXILOpClass;
-  def waveGetLaneCount : DXILOpClass;
-  def waveGetLaneIndex : DXILOpClass;
-  def waveIsFirstLane : DXILOpClass;
-  def waveMatch : DXILOpClass;
-  def waveMatrix_Accumulate : DXILOpClass;
-  def waveMatrix_Annotate : DXILOpClass;
-  def waveMatrix_Depth : DXILOpClass;
-  def waveMatrix_Fill : DXILOpClass;
-  def waveMatrix_LoadGroupShared : DXILOpClass;
-  def waveMatrix_LoadRawBuf : DXILOpClass;
-  def waveMatrix_Multiply : DXILOpClass;
-  def waveMatrix_ScalarOp : DXILOpClass;
-  def waveMatrix_StoreGroupShared : DXILOpClass;
-  def waveMatrix_StoreRawBuf : DXILOpClass;
-  def waveMultiPrefixBitCount : DXILOpClass;
-  def waveMultiPrefixOp : DXILOpClass;
-  def wavePrefixOp : DXILOpClass;
-  def waveReadLaneAt : DXILOpClass;
-  def waveReadLaneFirst : DXILOpClass;
-  def worldRayDirection : DXILOpClass;
-  def worldRayOrigin : DXILOpClass;
-  def worldToObject : DXILOpClass;
-  def writeSamplerFeedback : DXILOpClass;
-  def writeSamplerFeedbackBias : DXILOpClass;
-  def writeSamplerFeedbackGrad : DXILOpClass;
-  def writeSamplerFeedbackLevel: DXILOpClass;
+//  def acceptHitAndEndSearch : DXILOpClass;
+//  def allocateNodeOutputRecords : DXILOpClass;
+//  def allocateRayQuery : DXILOpClass;
+//  def annotateHandle : DXILOpClass;
+//  def annotateNodeHandle : DXILOpClass;
+//  def annotateNodeRecordHandle : DXILOpClass;
+//  def atomicBinOp : DXILOpClass;
+//  def atomicCompareExchange : DXILOpClass;
+//  def attributeAtVertex : DXILOpClass;
+//  def barrier : DXILOpClass;
+//  def barrierByMemoryHandle : DXILOpClass;
+//  def barrierByMemoryType : DXILOpClass;
+//  def barrierByNodeRecordHandle : DXILOpClass;
+def binary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
+//  def binaryWithCarryOrBorrow : DXILOpClass;
+//  def binaryWithTwoOuts : DXILOpClass;
+//  def bitcastF16toI16 : DXILOpClass;
+//  def bitcastF32toI32 : DXILOpClass;
+//  def bitcastF64toI64 : DXILOpClass;
+//  def bitcastI16toF16 : DXILOpClass;
+//  def bitcastI32toF32 : DXILOpClass;
+//  def bitcastI64toF64 : DXILOpClass;
+//  def bufferLoad : DXILOpClass;
+//  def bufferStore : DXILOpClass;
+//  def bufferUpdateCounter : DXILOpClass;
+//  def calculateLOD : DXILOpClass;
+//  def callShader : DXILOpClass;
+//  def cbufferLoad : DXILOpClass;
+//  def cbufferLoadLegacy : DXILOpClass;
+//  def checkAccessFullyMapped : DXILOpClass;
+//  def coverage : DXILOpClass;
+//  def createHandle : DXILOpClass;
+//  def createHandleForLib : DXILOpClass;
+//  def createHandleFromBinding : DXILOpClass;
+//  def createHandleFromHeap : DXILOpClass;
+//  def createNodeInputRecordHandle : DXILOpClass;
+//  def createNodeOutputHandle : DXILOpClass;
+//  def cutStream : DXILOpClass;
+//  def cycleCounterLegacy : DXILOpClass;
+//  def discard : DXILOpClass;
+//  def dispatchMesh : DXILOpClass;
+//  def dispatchRaysDimensions : DXILOpClass;
+//  def dispatchRaysIndex : DXILOpClass;
+//  def domainLocation : DXILOpClass;
+def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
+//  def dot2AddHalf : DXILOpClass;
+def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
+def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
+//  def dot4AddPacked : DXILOpClass;
+//  def emitIndices : DXILOpClass;
+//  def emitStream : DXILOpClass;
+//  def emitThenCutStream : DXILOpClass;
+//  def evalCentroid : DXILOpClass;
+//  def evalSampleIndex : DXILOpClass;
+//  def evalSnapped : DXILOpClass;
+//  def finishedCrossGroupSharing : DXILOpClass;
+def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
+//  def geometryIndex : DXILOpClass;
+//  def getDimensions : DXILOpClass;
+//  def getInputRecordCount : DXILOpClass;
+//  def getMeshPayload : DXILOpClass;
+//  def getNodeRecordPtr : DXILOpClass;
+//  def getRemainingRecursionLevels : DXILOpClass;
+def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+//  def gsInstanceID : DXILOpClass;
+//  def hitKind : DXILOpClass;
+//  def ignoreHit : DXILOpClass;
+//  def incrementOutputCount : DXILOpClass;
+//  def indexNodeHandle : DXILOpClass;
+//  def innerCoverage : DXILOpClass;
+//  def instanceID : DXILOpClass;
+//  def instanceIndex : DXILOpClass;
+//  def isHelperLane : DXILOpClass;
+def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
+//  def legacyDoubleToFloat : DXILOpClass;
+//  def legacyDoubleToSInt32 : DXILOpClass;
+//  def legacyDoubleToUInt32 : DXILOpClass;
+//  def legacyF16ToF32 : DXILOpClass;
+//  def legacyF32ToF16 : DXILOpClass;
+//  def loadInput : DXILOpClass;
+//  def loadOutputControlPoint : DXILOpClass;
+//  def loadPatchConstant : DXILOpClass;
+//  def makeDouble : DXILOpClass;
+//  def minPrecXRegLoad : DXILOpClass;
+//  def minPrecXRegStore : DXILOpClass;
+//  def nodeOutputIsValid : DXILOpClass;
+//  def objectRayDirection : DXILOpClass;
+//  def objectRayOrigin : DXILOpClass;
+//  def objectToWorld : DXILOpClass;
+//  def outputComplete : DXILOpClass;
+//  def outputControlPointID : DXILOpClass;
+//  def pack4x8 : DXILOpClass;
+//  def primitiveID : DXILOpClass;
+//  def primitiveIndex : DXILOpClass;
+//  def quadOp : DXILOpClass;
+//  def quadReadLaneAt : DXILOpClass;
+//  def quadVote : DXILOpClass;
+//  def quaternary : DXILOpClass;
+//  def rawBufferLoad : DXILOpClass;
+//  def rawBufferStore : DXILOpClass;
+//  def rayFlags : DXILOpClass;
+//  def rayQuery_Abort : DXILOpClass;
+//  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
+//  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
+//  def rayQuery_Proceed : DXILOpClass;
+//  def rayQuery_StateMatrix : DXILOpClass;
+//  def rayQuery_StateScalar : DXILOpClass;
+//  def rayQuery_StateVector : DXILOpClass;
+//  def rayQuery_TraceRayInline : DXILOpClass;
+//  def rayTCurrent : DXILOpClass;
+//  def rayTMin : DXILOpClass;
+//  def renderTargetGetSampleCount : DXILOpClass;
+//  def renderTargetGetSamplePosition : DXILOpClass;
+//  def reportHit : DXILOpClass;
+//  def sample : DXILOpClass;
+//  def sampleBias : DXILOpClass;
+//  def sampleCmp : DXILOpClass;
+//  def sampleCmpBias : DXILOpClass;
+//  def sampleCmpGrad : DXILOpClass;
+//  def sampleCmpLevel : DXILOpClass;
+//  def sampleCmpLevelZero : DXILOpClass;
+//  def sampleGrad : DXILOpClass;
+//  def sampleIndex : DXILOpClass;
+//  def sampleLevel : DXILOpClass;
+//  def setMeshOutputCounts : DXILOpClass;
+//  def splitDouble : DXILOpClass;
+//  def startInstanceLocation : DXILOpClass;
+//  def startVertexLocation : DXILOpClass;
+//  def storeOutput : DXILOpClass;
+//  def storePatchConstant : DXILOpClass;
+//  def storePrimitiveOutput : DXILOpClass;
+//  def storeVertexOutput : DXILOpClass;
+//  def tempRegLoad : DXILOpClass;
+//  def tempRegStore : DXILOpClass;
+def tertiary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
+//  def texture2DMSGetSamplePosition : DXILOpClass;
+//  def textureGather : DXILOpClass;
+//  def textureGatherCmp : DXILOpClass;
+//  def textureGatherRaw : DXILOpClass;
+//  def textureLoad : DXILOpClass;
+//  def textureStore : DXILOpClass;
+//  def textureStoreSample : DXILOpClass;
+def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+//  def traceRay : DXILOpClass;
+def unary : DXILOpClass<[llvm_any_ty, llvm_any_ty]>;
+//  def unaryBits : DXILOpClass;
+//  def unpack4x8 : DXILOpClass;
+//  def viewID : DXILOpClass;
+//  def waveActiveAllEqual : DXILOpClass;
+//  def waveActiveBallot : DXILOpClass;
+//  def waveActiveBit : DXILOpClass;
+//  def waveActiveOp : DXILOpClass;
+//  def waveAllOp : DXILOpClass;
+//  def waveAllTrue : DXILOpClass;
+//  def waveAnyTrue : DXILOpClass;
+//  def waveGetLaneCount : DXILOpClass;
+//  def waveGetLaneIndex : DXILOpClass;
+//  def waveIsFirstLane : DXILOpClass;
+//  def waveMatch : DXILOpClass;
+//  def waveMatrix_Accumulate : DXILOpClass;
+//  def waveMatrix_Annotate : DXILOpClass;
+//  def waveMatrix_Depth : DXILOpClass;
+//  def waveMatrix_Fill : DXILOpClass;
+//  def waveMatrix_LoadGroupShared : DXILOpClass;
+//  def waveMatrix_LoadRawBuf : DXILOpClass;
+//  def waveMatrix_Multiply : DXILOpClass;
+//  def waveMatrix_ScalarOp : DXILOpClass;
+//  def waveMatrix_StoreGroupShared : DXILOpClass;
+//  def waveMatrix_StoreRawBuf : DXILOpClass;
+//  def waveMultiPrefixBitCount : DXILOpClass;
+//  def waveMultiPrefixOp : DXILOpClass;
+//  def wavePrefixOp : DXILOpClass;
+//  def waveReadLaneAt : DXILOpClass;
+//  def waveReadLaneFirst : DXILOpClass;
+//  def worldRayDirection : DXILOpClass;
+//  def worldRayOrigin : DXILOpClass;
+//  def worldToObject : DXILOpClass;
+//  def writeSamplerFeedback : DXILOpClass;
+//  def writeSamplerFeedbackBias : DXILOpClass;
+//  def writeSamplerFeedbackGrad : DXILOpClass;
+//  def writeSamplerFeedbackLevel: DXILOpClass;
 
   // This is a sentinel definition. Hence placed at the end of the list
   // and not as part of the above alphabetically sorted valid definitions.
   // Additionally it is capitalized unlike all the others.
-  def UnknownOpClass: DXILOpClass;
-}
-
-// Several of the overloaded DXIL Operations support for data types
-// that are a subset of the overloaded LLVM intrinsics that they map to.
-// For e.g., llvm.sin.* intrinsic operates on any floating-point type and
-// maps for lowering to DXIL Op Sin. However, valid overloads of DXIL Sin
-// operation overloads are half (f16) and float (f32) only.
-//
-// The following abstracts overload types specific to DXIL operations.
-
-class DXILType : LLVMType<OtherVT> {
-  let isAny = 1;
-  int isI16OrI32 = 0;
-  int isHalfOrFloat = 0;
+def UnknownOpClass: DXILOpClass<[]>;
 }
 
-// Concrete records for various overload types supported specifically by
-// DXIL Operations.
-let isI16OrI32 = 1 in
-  def llvm_i16ori32_ty : DXILType;
-
-let isHalfOrFloat = 1 in
-  def llvm_halforfloat_ty : DXILType;
-
 // Abstraction DXIL Operation to LLVM intrinsic
 class DXILOpMappingBase {
   int OpCode = 0;                      // Opcode of DXIL Operation
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
+  list<LLVMType> OpOverloadTypes = ?;      // Fixed types valid for overload type
+                                       // of DXIL Operation
   string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
 }
 
-class DXILOpMapping<int opCode, DXILOpClass opClass,
-                    Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
-  int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
-  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
-  string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
+class DXILOpMapping<int opCode,
+                    Intrinsic intrinsic,
+                    string doc> : DXILOpMappingBase {
+  int OpCode = opCode;
+  Intrinsic LLVMIntrinsic = intrinsic;
+  string Doc = doc;
+}
+
+// Concrete definitions of DXIL Operation mapping to corresponding LLVM intrinsic
+
+// IsSpecialFloat Class
+let OpClass = isSpecialFloat, OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in
+  def IsInf : DXILOpMapping<9,  int_dx_isinf,
+                           "Determines if the specified value is infinite.">;
+
+// Unary Class
+let OpClass = unary in {
+  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in
+    def Abs : DXILOpMapping<6, int_fabs,
+                            "Returns the absolute value of the input.">;
+
+  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in {
+    def Cos  : DXILOpMapping<12, int_cos,
+                             "Returns cosine(theta) for theta in radians.">;
+    def Sin  : DXILOpMapping<13, int_sin,
+                             "Returns sine(theta) for theta in radians.">;
+    def Exp2 : DXILOpMapping<21, int_exp2,
+                             "Returns the base 2 exponential, or 2**x, of the"
+                             " specified value. exp2(x) = 2**x.">;
+    def Frac : DXILOpMapping<22, int_dx_frac,
+                             "Returns a fraction from 0 to 1 that represents the"
+                             " decimal part of the input.">;
+    def Log2 : DXILOpMapping<23, int_log2,
+                             "Returns the base-2 logarithm of the specified value.">;
+    def Sqrt : DXILOpMapping<24, int_sqrt,
+                             "Returns the square root of the specified floating-point"
+                             "value, per component.">;
+    def RSqrt : DXILOpMapping<25, int_dx_rsqrt,
+                             "Returns the reciprocal of the square root of the"
+                             " specified value. rsqrt(x) = 1 / sqrt(x).">;
+    def Round : DXILOpMapping<26, int_round,
+                              "Returns the input rounded to the nearest integer"
+                              "within a floating-point type.">;
+    def Floor : DXILOpMapping<27, int_floor,
+                              "Returns the largest integer that is less than or equal to the input.">;
+    def Trunc : DXILOpMapping<29, int_trunc,
+                         "Returns the specified value truncated to the integer component.">;
+  }
+  let OpOverloadTypes = [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
+    def Rbits : DXILOpMapping<30, int_bitreverse,
+                         "Returns the specified value with its bits reversed.">;
+
+  }
+}
+
+
+// Binary Class
+let OpClass = binary in {
+  // Float overloads
+  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
+    def FMax : DXILOpMapping<35, int_maxnum,
+                             "Float maximum. FMax(a,b) = a > b ? a : b">;
+    def FMin : DXILOpMapping<36, int_minnum,
+                             "Float minimum. FMin(a,b) = a < b ? a : b">;
+  }
+  // Int overloads
+  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
+    def SMax : DXILOpMapping<37, int_smax,
+                             "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
+    def SMin : DXILOpMapping<38, int_smin,
+                             "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
+    def UMax : DXILOpMapping<39, int_umax,
+                             "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+    def UMin : DXILOpMapping<40, int_umin,
+                             "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
+  }
 }
 
-// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
-def Abs : DXILOpMapping<6, unary, int_fabs,
-                         "Returns the absolute value of the input.">;
-def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
-                         "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
-def Cos  : DXILOpMapping<12, unary, int_cos,
-                         "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Exp2 : DXILOpMapping<21, unary, int_exp2,
-                         "Returns the base 2 exponential, or 2**x, of the specified value."
-                         "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Frac : DXILOpMapping<22, unary, int_dx_frac,
-                         "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Log2 : DXILOpMapping<23, unary, int_log2,
-                         "Returns the base-2 logarithm of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sqrt : DXILOpMapping<24, unary, int_sqrt,
-                         "Returns the square root of the specified floating-point"
-                         "value, per component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
-                         "Returns the reciprocal of the square root of the specified value."
-                         "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Round : DXILOpMapping<26, unary, int_roundeven,
-                         "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Floor : DXILOpMapping<27, unary, int_floor,
-                         "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Ceil : DXILOpMapping<28, unary, int_ceil,
-                         "Returns the smallest integer that is greater than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Trunc : DXILOpMapping<29, unary, int_trunc,
-                         "Returns the specified value truncated to the integer component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Rbits : DXILOpMapping<30, unary, int_bitreverse,
-                         "Returns the specified value with its bits reversed.",
-                         [llvm_anyint_ty, LLVMMatchType<0>]>;
-def FMax : DXILOpMapping<35, binary, int_maxnum,
-                         "Float maximum. FMax(a,b) = a > b ? a : b">;
-def FMin : DXILOpMapping<36, binary, int_minnum,
-                         "Float minimum. FMin(a,b) = a < b ? a : b">;
-def SMax : DXILOpMapping<37, binary, int_smax,
-                         "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
-def SMin : DXILOpMapping<38, binary, int_smin,
-                         "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
-def UMax : DXILOpMapping<39, binary, int_umax,
-                         "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
-def UMin : DXILOpMapping<40, binary, int_umin,
-                         "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
-def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
-                         "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
-def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
-                         "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
-def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
-                         "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
-  def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
-  def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
-  def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
-def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
-                             "Reads the thread ID">;
-def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
-                             "Reads the group ID (SV_GroupID)">;
-def ThreadIdInGroup : DXILOpMapping<95, threadIdInGroup,
-                                    int_dx_thread_id_in_group,
+// Tertiary Class
+let OpClass = tertiary in {
+  // Float overloads
+  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
+    def FMad : DXILOpMapping<46, int_fmuladd,
+                             "Floating point arithmetic multiply/add operation."
+                             " fmad(m,a,b) = m * a + b.">;
+  }
+  // Int overloads
+  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
+    def IMad : DXILOpMapping<48, int_dx_imad,
+                             "Signed integer arithmetic multiply/add operation."
+                             " imad(m,a,b) = m * a + b.">;
+    def UMad : DXILOpMapping<49, int_dx_umad,
+                         "Unsigned integer arithmetic multiply/add operation."
+                         " umad(m,a, = m * a + b.">;
+  }
+}
+
+// Dot Operations
+let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty] in {
+  let OpClass = dot2 in
+    def Dot2 : DXILOpMapping<54, int_dx_dot2,
+                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                             " ... + a[n]*b[n] where n is between 0 and 1">;
+  let OpClass = dot3 in
+    def Dot3 : DXILOpMapping<55, int_dx_dot3,
+                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                             " ... + a[n]*b[n] where n is between 0 and 2">;
+  let OpClass = dot4 in
+    def Dot4 : DXILOpMapping<56, int_dx_dot4,
+                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                             " ... + a[n]*b[n] where n is between 0 and 3">;
+}
+
+// Thread Operations
+let OpOverloadTypes = [llvm_i32_ty] in {
+  let OpClass =  threadId in
+    def ThreadId : DXILOpMapping<93, int_dx_thread_id,
+                                 "Reads the thread ID">;
+  let OpClass =  groupId in
+    def GroupId  : DXILOpMapping<94, int_dx_group_id,
+                                 "Reads the group ID (SV_GroupID)">;
+  let OpClass =  threadIdInGroup in
+    def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
                                     "Reads the thread ID within the group "
                                     "(SV_GroupThreadID)">;
-def FlattenedThreadIdInGroup : DXILOpMapping<96, flattenedThreadIdInGroup,
-                                             int_dx_flattened_thread_id_in_group,
-                                             "Provides a flattened index for a "
-                                             "given thread within a given "
-                                             "group (SV_GroupIndex)">;
+  let OpClass = flattenedThreadIdInGroup in
+    def FlattenedThreadIdInGroup : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
+                                                 "Provides a flattened index for a given"
+                                                 " thread within a given group (SV_GroupIndex)">;
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index f2504775d557f2..77f0d2cb8c937f 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
 #include "llvm/Support/DXILABI.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 #include <string>
@@ -41,6 +42,8 @@ struct DXILOperationDesc {
   StringRef Doc;      // the documentation description of this instruction
   SmallVector<Record *> OpTypes; // Vector of operand type records -
                                  // return type is at index 0
+  SmallVector<Record *> OpOverloads; // Vector of fixed types valid for
+                                     // operation overloads
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -91,15 +94,12 @@ static ParameterKind getParameterKind(const Record *R) {
     return ParameterKind::I32;
   case MVT::fAny:
   case MVT::iAny:
+  case MVT::Any:
     return ParameterKind::Overload;
-  case MVT::Other:
-    // Handle DXIL-specific overload types
-    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
-      return ParameterKind::Overload;
-    }
-    LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable("Support for specified DXIL Type not yet implemented");
+    report_fatal_error(
+        "Support for specified parameter type not yet implemented",
+        /*gen_crash_diag*/ false);
   }
 }
 
@@ -113,9 +113,9 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   OpCode = R->getValueAsInt("OpCode");
 
   Doc = R->getValueAsString("Doc");
-
-  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
-  unsigned TypeRecsSize = TypeRecs.size();
+  Record *OpClassRec = R->getValueAsDef("OpClass");
+  auto ParamTypeRecs = OpClassRec->getValueAsListOfDefs("OpSignature");
+  unsigned ParamTypeRecsSize = ParamTypeRecs.size();
   // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
@@ -124,32 +124,32 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // the comment before the definition of class LLVMMatchType in
   // llvm/IR/Intrinsics.td
   SmallVector<int> OverloadParamIndices;
-  for (unsigned i = 0; i < TypeRecsSize; i++) {
-    auto TR = TypeRecs[i];
+  for (unsigned I = 0; I < ParamTypeRecsSize; I++) {
+    auto TR = ParamTypeRecs[I];
     // Track operation parameter indices of any overload types
-    auto isAny = TR->getValueAsInt("isAny");
-    if (isAny == 1) {
+    auto IsAny = TR->getValueAsInt("isAny");
+    if (IsAny == 1) {
       // TODO: At present it is expected that all overload types in a DXIL Op
       // are of the same type. Hence, OverloadParamIndices will have only one
       // element. This implies we do not need a vector. However, until more
       // (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
       // cases this assumption would not hold.
       if (!OverloadParamIndices.empty()) {
-        bool knownType = true;
+        bool KnownType = true;
         // Ensure that the same overload type registered earlier is being used
         for (auto Idx : OverloadParamIndices) {
-          if (TR != TypeRecs[Idx]) {
-            knownType = false;
+          if (TR != ParamTypeRecs[Idx]) {
+            KnownType = false;
             break;
           }
         }
-        if (!knownType) {
+        if (!KnownType) {
           report_fatal_error("Specification of multiple differing overload "
                              "parameter types not yet supported",
-                             false);
+                             /*gen_crash_diag*/ false);
         }
       } else {
-        OverloadParamIndices.push_back(i);
+        OverloadParamIndices.push_back(I);
       }
     }
     // Populate OpTypes array according to the type specification
@@ -160,7 +160,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
       // Get the parameter index of anonymous type, TR, references
       auto OLParamIndex = TR->getValueAsInt("Number");
       // Resolve and insert the type to that at OLParamIndex
-      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+      OpTypes.emplace_back(ParamTypeRecs[OLParamIndex]);
     } else {
       // A non-anonymous type. Just record it in OpTypes
       OpTypes.emplace_back(TR);
@@ -172,9 +172,18 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   if (!OverloadParamIndices.empty()) {
     if (OverloadParamIndices.size() > 1)
       report_fatal_error("Multiple overload type specification not supported",
-                         false);
+                         /*gen_crash_diag*/ false);
     OverloadParamIndex = OverloadParamIndices[0];
   }
+
+  // Get valid overload types of the Operation
+  auto OverloadTypeRecs = R->getValueAsListOfDefs("OpOverloadTypes");
+  unsigned OverloadTypeRecsSize = OverloadTypeRecs.size();
+  // Populate OpOverloads with the valid overload type records
+  for (unsigned I = 0; I < OverloadTypeRecsSize; I++) {
+    OpOverloads.emplace_back(OverloadTypeRecs[I]);
+  }
+
   // Get the operation class
   OpClass = R->getValueAsDef("OpClass")->getName();
 
@@ -188,10 +197,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
     // that of the intrinsic. Deviations are expected to be encoded in TableGen
     // record specification and handled accordingly here. Support to be added
     // as needed.
-    auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
+    ListInit *IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
     auto IntrPropListSize = IntrPropList->size();
-    for (unsigned i = 0; i < IntrPropListSize; i++) {
-      OpAttributes.emplace_back(IntrPropList->getElement(i)->getAsString());
+    for (unsigned I = 0; I < IntrPropListSize; I++) {
+      OpAttributes.emplace_back(IntrPropList->getElement(I)->getAsString());
     }
   }
 }
@@ -233,16 +242,9 @@ static std::string getParameterKindStr(ParameterKind Kind) {
   llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
 }
 
-/// Return a string representation of OverloadKind enum that maps to
-/// input LLVMType record
-/// \param R TableGen def record of class LLVMType
-/// \return std::string string representation of OverloadKind
-
 static std::string getOverloadKindStr(const Record *R) {
-  auto VTRec = R->getValueAsDef("VT");
+  Record *VTRec = R->getValueAsDef("VT");
   switch (getValueType(VTRec)) {
-  case MVT::isVoid:
-    return "OverloadKind::VOID";
   case MVT::f16:
     return "OverloadKind::HALF";
   case MVT::f32:
@@ -259,24 +261,28 @@ static std::string getOverloadKindStr(const Record *R) {
     return "OverloadKind::I32";
   case MVT::i64:
     return "OverloadKind::I64";
-  case MVT::iAny:
-    return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
-  case MVT::fAny:
-    return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
-  case MVT::Other:
-    // Handle DXIL-specific overload types
-    {
-      if (R->getValueAsInt("isHalfOrFloat")) {
-        return "OverloadKind::HALF | OverloadKind::FLOAT";
-      } else if (R->getValueAsInt("isI16OrI32")) {
-        return "OverloadKind::I16 | OverloadKind::I32";
-      }
-    }
-    LLVM_FALLTHROUGH;
   default:
-    llvm_unreachable(
-        "Support for specified parameter OverloadKind not yet implemented");
+    llvm_unreachable("Support for specified fixed type option for overload "
+                     "type not supported");
+  }
+}
+/// Return a string representation of the OverloadKind enums that map to
+/// the input LLVMType records, joined by " | "
+/// \param OverloadTys TableGen def records of class LLVMType
+/// \return std::string string representation of the OverloadKinds
+
+static std::string
+getOverloadKindStrs(const SmallVector<Record *> OverloadTys) {
+  if (OverloadTys.empty()) {
+    return {};
+  }
+  std::string OverloadString = "";
+  auto Iter = OverloadTys.begin();
+  OverloadString.append(getOverloadKindStr(*Iter++));
+  for (; Iter != OverloadTys.end(); ++Iter) {
+    OverloadString.append(" | ").append(getOverloadKindStr(*Iter));
   }
+  return OverloadString;
 }
 
 /// Emit Enums of DXIL Ops
@@ -411,7 +417,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
-       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
+       << getOverloadKindStrs(Op.OpOverloads) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
@@ -477,7 +483,7 @@ static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
   OS << "\n";
   // Get all DXIL Ops to intrinsic mapping records
   std::vector<Record *> OpIntrMaps =
-      Records.getAllDerivedDefinitions("DXILOpMapping");
+      Records.getAllDerivedDefinitions("DXILOpMappingBase");
   std::vector<DXILOperationDesc> DXILOps;
   for (auto *Record : OpIntrMaps) {
     DXILOps.emplace_back(DXILOperationDesc(Record));

>From 47df485d76cf4c7c4c16ce5278641ff730ff4a71 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 29 Mar 2024 18:33:05 -0400
Subject: [PATCH 2/6] Add support to specify overload types specific to Shader
 Model for TableGen records of DXIL Operations.

---
 llvm/lib/Target/DirectX/DXIL.td            | 240 +++++++++++----------
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp  |  48 ++++-
 llvm/lib/Target/DirectX/DXILOpBuilder.h    |  15 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |  23 +-
 llvm/test/CodeGen/DirectX/abs.ll           |   2 +-
 llvm/test/CodeGen/DirectX/ceil.ll          |   3 +-
 llvm/test/CodeGen/DirectX/ceil_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/clamp.ll         |   2 +-
 llvm/test/CodeGen/DirectX/cos.ll           |   2 +-
 llvm/test/CodeGen/DirectX/cos_error.ll     |   2 +-
 llvm/test/CodeGen/DirectX/dot2_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/dot3_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/dot4_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/exp.ll           |   2 +-
 llvm/test/CodeGen/DirectX/exp2_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/fabs.ll          |   3 +-
 llvm/test/CodeGen/DirectX/fdot.ll          |   2 +-
 llvm/test/CodeGen/DirectX/floor.ll         |   2 +-
 llvm/test/CodeGen/DirectX/floor_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/fmax.ll          |   2 +-
 llvm/test/CodeGen/DirectX/fmin.ll          |   2 +-
 llvm/test/CodeGen/DirectX/frac_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/idot.ll          |   2 +-
 llvm/test/CodeGen/DirectX/isinf.ll         |   2 +-
 llvm/test/CodeGen/DirectX/isinf_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/log.ll           |   2 +-
 llvm/test/CodeGen/DirectX/log10.ll         |   2 +-
 llvm/test/CodeGen/DirectX/log2.ll          |   2 +-
 llvm/test/CodeGen/DirectX/log2_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/pow.ll           |   2 +-
 llvm/test/CodeGen/DirectX/reversebits.ll   |   2 +-
 llvm/test/CodeGen/DirectX/round.ll         |   2 +-
 llvm/test/CodeGen/DirectX/round_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/rsqrt.ll         |   2 +-
 llvm/test/CodeGen/DirectX/rsqrt_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/sin.ll           |   6 +-
 llvm/test/CodeGen/DirectX/sin_error.ll     |   2 +-
 llvm/test/CodeGen/DirectX/sin_sm_error.ll  |  15 ++
 llvm/test/CodeGen/DirectX/smax.ll          |   2 +-
 llvm/test/CodeGen/DirectX/smin.ll          |   2 +-
 llvm/test/CodeGen/DirectX/sqrt.ll          |   2 +-
 llvm/test/CodeGen/DirectX/sqrt_error.ll    |   2 +-
 llvm/test/CodeGen/DirectX/trunc.ll         |   2 +-
 llvm/test/CodeGen/DirectX/trunc_error.ll   |   2 +-
 llvm/test/CodeGen/DirectX/umax.ll          |   2 +-
 llvm/test/CodeGen/DirectX/umin.ll          |   2 +-
 llvm/utils/TableGen/DXILEmitter.cpp        |  52 +++--
 47 files changed, 298 insertions(+), 183 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/sin_sm_error.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 082e7596aded35..0ea0f016fd6c74 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,9 +13,33 @@
 
 include "llvm/IR/Intrinsics.td"
 
-// Absttraction of DXIL Operation class.
+// Abstract class to demarcate minimum Shader model version required
+// to support DXIL Op
+class DXILShaderModel<int major, int minor> {
+  int MajorAndMinor = !add(!mul(major, 10), minor);
+}
+
+// Valid minimum Shader model version records
+
+// Shader Model 6.x
+foreach i = 0...9 in {
+  def SM6_#i : DXILShaderModel<6, i>;
+}
+// Shader Model 7.x - for now only 7.0 is defined. Extend as needed
+foreach i = 0 in {
+  def SM7_#i : DXILShaderModel<7, i>;
+}
+
+// Abstraction of a class mapping valid DXIL Op overloads to the minimum
+// Shader Model version in which they are supported
+class DXILOpOverload<DXILShaderModel minsm, list<LLVMType> overloads> {
+  DXILShaderModel ShaderModel = minsm;
+  list<LLVMType> OpOverloads = overloads;
+}
+
+// Abstraction of DXIL Operation class.
 // It encapsulates an associated function signature viz.,
-// returnty(param1, param2, ...) represented as a list of LLVMTypes.
+// returnTy(param1Ty, param2Ty, ...) represented as a list of LLVMTypes.
 // DXIL Ops that belong to a DXILOpClass record the signature of that
 // DXILOpClass
 
@@ -30,21 +54,21 @@ class DXILOpClass<list<LLVMType> OpSig> {
 // name has the format "dx.op.<class-name>.<return-type>", in most cases.
 
 defset list<DXILOpClass> OpClasses = {
-//  def acceptHitAndEndSearch : DXILOpClass;
+def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
 //  def allocateNodeOutputRecords : DXILOpClass;
-//  def allocateRayQuery : DXILOpClass;
+def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
 //  def annotateHandle : DXILOpClass;
 //  def annotateNodeHandle : DXILOpClass;
 //  def annotateNodeRecordHandle : DXILOpClass;
 //  def atomicBinOp : DXILOpClass;
 //  def atomicCompareExchange : DXILOpClass;
-//  def attributeAtVertex : DXILOpClass;
-//  def barrier : DXILOpClass;
+def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
+def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
 //  def barrierByMemoryHandle : DXILOpClass;
-//  def barrierByMemoryType : DXILOpClass;
+def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
 //  def barrierByNodeRecordHandle : DXILOpClass;
-def binary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
-//  def binaryWithCarryOrBorrow : DXILOpClass;
+def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
+def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
 //  def binaryWithTwoOuts : DXILOpClass;
 //  def bitcastF16toI16 : DXILOpClass;
 //  def bitcastF32toI32 : DXILOpClass;
@@ -164,7 +188,7 @@ def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
 //  def storeVertexOutput : DXILOpClass;
 //  def tempRegLoad : DXILOpClass;
 //  def tempRegStore : DXILOpClass;
-def tertiary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty, llvm_any_ty]>;
+def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 //  def texture2DMSGetSamplePosition : DXILOpClass;
 //  def textureGather : DXILOpClass;
 //  def textureGatherCmp : DXILOpClass;
@@ -175,7 +199,7 @@ def tertiary : DXILOpClass<[llvm_any_ty, llvm_any_ty, llvm_any_ty, llvm_any_ty]>
 def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
 def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
 //  def traceRay : DXILOpClass;
-def unary : DXILOpClass<[llvm_any_ty, llvm_any_ty]>;
+def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
 //  def unaryBits : DXILOpClass;
 //  def unpack4x8 : DXILOpClass;
 //  def viewID : DXILOpClass;
@@ -224,138 +248,128 @@ class DXILOpMappingBase {
   int OpCode = 0;                      // Opcode of DXIL Operation
   DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
-  list<LLVMType> OpOverloadTypes = ?;      // Fixed types valid for overload type
+  list<DXILOpOverload> OpOverloadTypes = ?; // Valid overload type
                                        // of DXIL Operation
   string Doc = "";                     // A short description of the operation
 }
 
 class DXILOpMapping<int opCode,
                     Intrinsic intrinsic,
+                    list<DXILOpOverload> overloadTypes,
                     string doc> : DXILOpMappingBase {
   int OpCode = opCode;
   Intrinsic LLVMIntrinsic = intrinsic;
+  list<DXILOpOverload> OpOverloadTypes = overloadTypes;
   string Doc = doc;
 }
 
 // Concrete definitions of DXIL Operation mapping to corresponding LLVM intrinsic
 
 // IsSpecialFloat Class
-let OpClass = isSpecialFloat, OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in
-  def IsInf : DXILOpMapping<9,  int_dx_isinf,
+let OpClass = isSpecialFloat in {
+  def IsInf : DXILOpMapping<9,  int_dx_isinf, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
                            "Determines if the specified value is infinite.">;
+}
 
 // Unary Class
 let OpClass = unary in {
-  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in
-    def Abs : DXILOpMapping<6, int_fabs,
-                            "Returns the absolute value of the input.">;
-
-  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty] in {
-    def Cos  : DXILOpMapping<12, int_cos,
-                             "Returns cosine(theta) for theta in radians.">;
-    def Sin  : DXILOpMapping<13, int_sin,
-                             "Returns sine(theta) for theta in radians.">;
-    def Exp2 : DXILOpMapping<21, int_exp2,
-                             "Returns the base 2 exponential, or 2**x, of the"
-                             " specified value. exp2(x) = 2**x.">;
-    def Frac : DXILOpMapping<22, int_dx_frac,
-                             "Returns a fraction from 0 to 1 that represents the"
-                             " decimal part of the input.">;
-    def Log2 : DXILOpMapping<23, int_log2,
-                             "Returns the base-2 logarithm of the specified value.">;
-    def Sqrt : DXILOpMapping<24, int_sqrt,
-                             "Returns the square root of the specified floating-point"
-                             "value, per component.">;
-    def RSqrt : DXILOpMapping<25, int_dx_rsqrt,
-                             "Returns the reciprocal of the square root of the"
-                             " specified value. rsqrt(x) = 1 / sqrt(x).">;
-    def Round : DXILOpMapping<26, int_round,
-                              "Returns the input rounded to the nearest integer"
-                              "within a floating-point type.">;
-    def Floor : DXILOpMapping<27, int_floor,
-                              "Returns the largest integer that is less than or equal to the input.">;
-    def Trunc : DXILOpMapping<29, int_trunc,
-                         "Returns the specified value truncated to the integer component.">;
-  }
-  let OpOverloadTypes = [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
-    def Rbits : DXILOpMapping<30, int_bitreverse,
-                         "Returns the specified value with its bits reversed.">;
+  def Abs : DXILOpMapping<6, int_fabs, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                          "Returns the absolute value of the input.">;
 
-  }
+  def Cos  : DXILOpMapping<12, int_cos, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                          "Returns cosine(theta) for theta in radians.">;
+  def Sin  : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_3, [llvm_half_ty, llvm_float_ty]>,
+                                         DXILOpOverload<SM6_0, [llvm_float_ty]>],
+                           "Returns sine(theta) for theta in radians.">;
+  def Exp2 : DXILOpMapping<21, int_exp2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the base 2 exponential, or 2**x, of the"
+                           " specified value. exp2(x) = 2**x.">;
+  def Frac : DXILOpMapping<22, int_dx_frac, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns a fraction from 0 to 1 that represents the"
+                            " decimal part of the input.">;
+  def Log2 : DXILOpMapping<23, int_log2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the base-2 logarithm of the specified value.">;
+  def Sqrt : DXILOpMapping<24, int_sqrt, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "Returns the square root of the specified floating-point"
+                           "value, per component.">;
+  def RSqrt : DXILOpMapping<25, int_dx_rsqrt, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the reciprocal of the square root of the"
+                            " specified value. rsqrt(x) = 1 / sqrt(x).">;
+  def Round : DXILOpMapping<26, int_roundeven, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the input rounded to the nearest integer"
+                            "within a floating-point type.">;
+  def Floor : DXILOpMapping<27, int_floor, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the largest integer that is less than or equal to the input.">;
+  def Ceil  : DXILOpMapping<28, int_ceil, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the smallest integer that is greater than or equal to the input.">;
+  def Trunc : DXILOpMapping<29, int_trunc, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "Returns the specified value truncated to the integer component.">;
+  def Rbits : DXILOpMapping<30, int_bitreverse, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                            "Returns the specified value with its bits reversed.">;
 }
 
-
 // Binary Class
 let OpClass = binary in {
-  // Float overloads
-  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
-    def FMax : DXILOpMapping<35, int_maxnum,
-                             "Float maximum. FMax(a,b) = a > b ? a : b">;
-    def FMin : DXILOpMapping<36, int_minnum,
-                             "Float minimum. FMin(a,b) = a < b ? a : b">;
-  }
-  // Int overloads
-  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
-    def SMax : DXILOpMapping<37, int_smax,
-                             "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
-    def SMin : DXILOpMapping<38, int_smin,
-                             "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
-    def UMax : DXILOpMapping<39, int_umax,
-                             "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
-    def UMin : DXILOpMapping<40, int_umin,
-                             "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
-  }
+// Float overloads
+  def FMax : DXILOpMapping<35, int_maxnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Float maximum. FMax(a,b) = a > b ? a : b">;
+  def FMin : DXILOpMapping<36, int_minnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                           "Float minimum. FMin(a,b) = a < b ? a : b">;
+// Int overloads
+  def SMax : DXILOpMapping<37, int_smax, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
+  def SMin : DXILOpMapping<38, int_smin, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
+  def UMax : DXILOpMapping<39, int_umax, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+  def UMin : DXILOpMapping<40, int_umin, [DXILOpOverload<SM6_0,[llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                           "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
 }
 
 // Tertiary Class
 let OpClass = tertiary in {
-  // Float overloads
-  let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
-    def FMad : DXILOpMapping<46, int_fmuladd,
-                             "Floating point arithmetic multiply/add operation."
-                             " fmad(m,a,b) = m * a + b.">;
-  }
-  // Int overloads
-  let OpOverloadTypes =  [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty] in {
-    def IMad : DXILOpMapping<48, int_dx_imad,
-                             "Signed integer arithmetic multiply/add operation."
-                             " imad(m,a,b) = m * a + b.">;
-    def UMad : DXILOpMapping<49, int_dx_umad,
-                         "Unsigned integer arithmetic multiply/add operation."
-                         " umad(m,a, = m * a + b.">;
-  }
+// Float overloads
+//   let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
+  def FMad : DXILOpMapping<46, int_fmuladd, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
+                            "Floating point arithmetic multiply/add operation."
+                            " fmad(m,a,b) = m * a + b.">;
+// Int overloads
+def IMad : DXILOpMapping<48, int_dx_imad, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                         "Signed integer arithmetic multiply/add operation."
+                          " imad(m,a,b) = m * a + b.">;
+def UMad : DXILOpMapping<49, int_dx_umad, [DXILOpOverload<SM6_0, [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty]>],
+                        "Unsigned integer arithmetic multiply/add operation."
+                        " umad(m,a, = m * a + b.">;
 }
 
 // Dot Operations
-let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty] in {
-  let OpClass = dot2 in
-    def Dot2 : DXILOpMapping<54, int_dx_dot2,
-                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
-                             " ... + a[n]*b[n] where n is between 0 and 1">;
-  let OpClass = dot3 in
-    def Dot3 : DXILOpMapping<55, int_dx_dot3,
-                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
-                             " ... + a[n]*b[n] where n is between 0 and 2">;
-  let OpClass = dot4 in
-    def Dot4 : DXILOpMapping<56, int_dx_dot4,
-                             "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
-                             " ... + a[n]*b[n] where n is between 0 and 3">;
-}
+// let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty] in
+let OpClass = dot2 in
+  def Dot2 : DXILOpMapping<54, int_dx_dot2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                          "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                          " ... + a[n]*b[n] where n is between 0 and 1">;
+let OpClass = dot3 in
+  def Dot3 : DXILOpMapping<55, int_dx_dot3, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                           " ... + a[n]*b[n] where n is between 0 and 2">;
+let OpClass = dot4 in
+   def Dot4 : DXILOpMapping<56, int_dx_dot4, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
+                            "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
+                            " ... + a[n]*b[n] where n is between 0 and 3">;
 
 // Thread Operations
-let OpOverloadTypes = [llvm_i32_ty] in {
-  let OpClass =  threadId in
-    def ThreadId : DXILOpMapping<93, int_dx_thread_id,
-                                 "Reads the thread ID">;
-  let OpClass =  groupId in
-    def GroupId  : DXILOpMapping<94, int_dx_group_id,
-                                 "Reads the group ID (SV_GroupID)">;
-  let OpClass =  threadIdInGroup in
-    def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
-                                    "Reads the thread ID within the group "
-                                    "(SV_GroupThreadID)">;
-  let OpClass = flattenedThreadIdInGroup in
-    def FlattenedThreadIdInGroup : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
-                                                 "Provides a flattened index for a given"
-                                                 " thread within a given group (SV_GroupIndex)">;
-}
+let OpClass =  threadId in
+  def ThreadId : DXILOpMapping<93, int_dx_thread_id, [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                              "Reads the thread ID">;
+let OpClass =  groupId in
+  def GroupId  : DXILOpMapping<94, int_dx_group_id, [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                               "Reads the group ID (SV_GroupID)">;
+let OpClass =  threadIdInGroup in
+  def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group, [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                                     "Reads the thread ID within the group "
+                                     "(SV_GroupThreadID)">;
+let OpClass = flattenedThreadIdInGroup in
+  def FlattenedThreadIdInGroup : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
+                                               [DXILOpOverload<SM6_0, [llvm_i32_ty]>],
+                                                "Provides a flattened index for a given"
+                                                " thread within a given group (SV_GroupIndex)">;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 0b3982ea0f438a..1a4f8c709c0fd2 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <string>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -123,6 +124,11 @@ static std::string getTypeName(OverloadKind Kind, Type *Ty) {
   }
 }
 
+struct OpSMOverloadProp {
+  uint16_t ShaderModelVer;
+  uint16_t ValidTys;
+};
+
 // Static properties.
 struct OpCodeProperty {
   dxil::OpCode OpCode;
@@ -131,7 +137,7 @@ struct OpCodeProperty {
   dxil::OpCodeClass OpCodeClass;
   // Offset in DXILOpCodeClassNameTable.
   unsigned OpCodeClassNameOffset;
-  uint16_t OverloadTys;
+  std::vector<OpSMOverloadProp> OverloadProp;
   llvm::Attribute::AttrKind FuncAttr;
   int OverloadParamIndex;        // parameter index which control the overload.
                                  // When < 0, should be only 1 overload type.
@@ -249,16 +255,38 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
 }
 
+static uint16_t getValidOverloadMask(const OpCodeProperty *Prop,
+                                     uint32_t SMVer) {
+  uint16_t ValidTyMask = 0;
+  // Prop->OverloadProp is in ascending order of SM version.
+  // The overloads of the highest SM version that is not greater than
+  // SMVer are the ones that are valid for SMVer.
+  for (auto OL : Prop->OverloadProp) {
+    if (OL.ShaderModelVer <= SMVer) {
+      ValidTyMask = OL.ValidTys;
+    } else {
+      break;
+    }
+  }
+  return ValidTyMask;
+}
+
 namespace llvm {
 namespace dxil {
 
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                                          Type *OverloadTy,
+CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
+                                          Type *ReturnTy, Type *OverloadTy,
                                           SmallVector<Value *> Args) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+  uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
 
+  if (ValidTyMask == 0) {
+    report_fatal_error(StringRef(std::to_string(SMVer).append(
+                           ": Unhandled Shader Model Version")),
+                       /*gen_crash_diag*/ false);
+  }
   OverloadKind Kind = getOverloadKind(OverloadTy);
-  if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
+  if ((ValidTyMask & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }
 
@@ -276,14 +304,22 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
   return B.CreateCall(DXILFn, Args);
 }
 
-Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
+Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
+                                   FunctionType *FT) {
 
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
   // If DXIL Op has no overload parameter, just return the
   // precise return type specified.
   if (Prop->OverloadParamIndex < 0) {
     auto &Ctx = FT->getContext();
-    switch (Prop->OverloadTys) {
+    uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
+    if (ValidTyMask == 0) {
+      report_fatal_error(StringRef(std::to_string(SMVer).append(
+                             ": Unhandled Shader Model Version")),
+                         /*gen_crash_diag*/ false);
+    }
+
+    switch (ValidTyMask) {
     case OverloadKind::VOID:
       return Type::getVoidTy(Ctx);
     case OverloadKind::HALF:
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 5babeae470178b..1e15286c810a8a 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -14,6 +14,7 @@
 
 #include "DXILConstants.h"
 #include "llvm/ADT/SmallVector.h"
+#include <cstdint>
 
 namespace llvm {
 class Module;
@@ -30,13 +31,17 @@ class DXILOpBuilder {
 public:
   DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
   /// Create an instruction that calls DXIL Op with return type, specified
-  /// opcode, and call arguments. \param OpCode Opcode of the DXIL Op call
-  /// constructed \param ReturnTy Return type of the DXIL Op call constructed
+  /// opcode, and call arguments.
+  ///
+  /// \param OpCode Opcode of the DXIL Op call constructed
+  /// \param SMVer Shader Model version the DXIL Op call is constructed for
+  /// \param ReturnTy Return type of the DXIL Op call constructed
   /// \param OverloadTy Overload type of the DXIL Op call constructed
   /// \return DXIL Op call constructed
-  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
-                             Type *OverloadTy, SmallVector<Value *> Args);
-  Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
+  CallInst *createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
+                             Type *ReturnTy, Type *OverloadTy,
+                             SmallVector<Value *> Args);
+  Type *getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
 private:
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f09e322f88e1fd..c3217d51ece1b9 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/MC/TargetRegistry.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ErrorHandling.h"
 
@@ -72,10 +73,26 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   return NewOperands;
 }
 
+static uint32_t getShaderModelVer(Module &M) {
+  std::string TTStr = M.getTargetTriple();
+  std::string Error;
+  auto Target = TargetRegistry::lookupTarget(TTStr, Error);
+  if (!Target) {
+    if (TTStr.empty()) {
+      report_fatal_error(StringRef(Error), /*gen_crash_diag*/ false);
+    }
+  }
+  auto Major = Triple(TTStr).getOSVersion().getMajor();
+  auto MinorOrErr = Triple(TTStr).getOSVersion().getMinor();
+  uint32_t Minor = MinorOrErr.has_value() ? *MinorOrErr : 0;
+  return ((Major * 10) + Minor);
+}
+
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   DXILOpBuilder DXILB(M, B);
-  Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
+  uint32_t SMVer = getShaderModelVer(M);
+  Type *OverloadTy = DXILB.getOverloadTy(DXILOp, SMVer, F.getFunctionType());
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
     if (!CI)
@@ -91,8 +108,8 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     } else
       Args.append(CI->arg_begin(), CI->arg_end());
 
-    CallInst *DXILCI =
-        DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, SMVer, F.getReturnType(),
+                                              OverloadTy, Args);
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/abs.ll b/llvm/test/CodeGen/DirectX/abs.ll
index 822580e8c089af..73bc00042ac161 100644
--- a/llvm/test/CodeGen/DirectX/abs.ll
+++ b/llvm/test/CodeGen/DirectX/abs.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for abs are generated for int16_t/int/int64_t.
 
diff --git a/llvm/test/CodeGen/DirectX/ceil.ll b/llvm/test/CodeGen/DirectX/ceil.ll
index 15854714678014..c7583cb70ba19f 100644
--- a/llvm/test/CodeGen/DirectX/ceil.ll
+++ b/llvm/test/CodeGen/DirectX/ceil.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for ceil are generated for float and half.
 
@@ -18,3 +18,4 @@ entry:
 
 declare half @llvm.ceil.f16(half)
 declare float @llvm.ceil.f32(float)
+
diff --git a/llvm/test/CodeGen/DirectX/ceil_error.ll b/llvm/test/CodeGen/DirectX/ceil_error.ll
index 1b554d8715566e..da6f083550186c 100644
--- a/llvm/test/CodeGen/DirectX/ceil_error.ll
+++ b/llvm/test/CodeGen/DirectX/ceil_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation ceil does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/clamp.ll b/llvm/test/CodeGen/DirectX/clamp.ll
index f122313b8d7dcc..2f29e4479f9ca1 100644
--- a/llvm/test/CodeGen/DirectX/clamp.ll
+++ b/llvm/test/CodeGen/DirectX/clamp.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for clamp/uclamp are generated for half/float/double/i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/cos.ll b/llvm/test/CodeGen/DirectX/cos.ll
index 00f2e2c3f6e5ab..72f4bfca23f9d5 100644
--- a/llvm/test/CodeGen/DirectX/cos.ll
+++ b/llvm/test/CodeGen/DirectX/cos.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for cos are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/cos_error.ll b/llvm/test/CodeGen/DirectX/cos_error.ll
index a074f5b493dfd6..6bb85a7cec1e30 100644
--- a/llvm/test/CodeGen/DirectX/cos_error.ll
+++ b/llvm/test/CodeGen/DirectX/cos_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation cos does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
index a27bfaedacd573..54780d18e71fb4 100644
--- a/llvm/test/CodeGen/DirectX/dot2_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation dot2 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll
index eb69fb145038aa..242716b0b71bad 100644
--- a/llvm/test/CodeGen/DirectX/dot3_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot3_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation dot3 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll
index 5cd632684c0c01..731adda153def8 100644
--- a/llvm/test/CodeGen/DirectX/dot4_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot4_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation dot4 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/exp.ll b/llvm/test/CodeGen/DirectX/exp.ll
index fdafc1438cf0e8..f67e2744c4ee34 100644
--- a/llvm/test/CodeGen/DirectX/exp.ll
+++ b/llvm/test/CodeGen/DirectX/exp.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for exp are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/exp2_error.ll b/llvm/test/CodeGen/DirectX/exp2_error.ll
index 6b9126785fd4b8..4d13f936eb6be2 100644
--- a/llvm/test/CodeGen/DirectX/exp2_error.ll
+++ b/llvm/test/CodeGen/DirectX/exp2_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation exp2 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/fabs.ll b/llvm/test/CodeGen/DirectX/fabs.ll
index 3b3f8aa9a4a928..1b3e91dcfb30f6 100644
--- a/llvm/test/CodeGen/DirectX/fabs.ll
+++ b/llvm/test/CodeGen/DirectX/fabs.ll
@@ -1,8 +1,7 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for abs are generated for float, half, and double.
 
-
 ; CHECK-LABEL: fabs_half
 define noundef half @fabs_half(half noundef %a) {
 entry:
diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll
index 3e13b2ad2650c8..bad65cfb1e561f 100644
--- a/llvm/test/CodeGen/DirectX/fdot.ll
+++ b/llvm/test/CodeGen/DirectX/fdot.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
diff --git a/llvm/test/CodeGen/DirectX/floor.ll b/llvm/test/CodeGen/DirectX/floor.ll
index b033e2eaa491e7..f667cab4aa249b 100644
--- a/llvm/test/CodeGen/DirectX/floor.ll
+++ b/llvm/test/CodeGen/DirectX/floor.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for floor are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/floor_error.ll b/llvm/test/CodeGen/DirectX/floor_error.ll
index 3b51a4b543b7f6..e3190e5afb63fa 100644
--- a/llvm/test/CodeGen/DirectX/floor_error.ll
+++ b/llvm/test/CodeGen/DirectX/floor_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation floor does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/fmax.ll b/llvm/test/CodeGen/DirectX/fmax.ll
index aff722c29309c0..05852ee33486d1 100644
--- a/llvm/test/CodeGen/DirectX/fmax.ll
+++ b/llvm/test/CodeGen/DirectX/fmax.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for fmax are generated for half/float/double.
 
diff --git a/llvm/test/CodeGen/DirectX/fmin.ll b/llvm/test/CodeGen/DirectX/fmin.ll
index 2f7c209f0278ae..1c6c7ca3f2e38a 100644
--- a/llvm/test/CodeGen/DirectX/fmin.ll
+++ b/llvm/test/CodeGen/DirectX/fmin.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for fmin are generated for half/float/double.
 
diff --git a/llvm/test/CodeGen/DirectX/frac_error.ll b/llvm/test/CodeGen/DirectX/frac_error.ll
index ebce76105ad4d7..1bc3558ab0c9a5 100644
--- a/llvm/test/CodeGen/DirectX/frac_error.ll
+++ b/llvm/test/CodeGen/DirectX/frac_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation frac does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll
index 9f89a8d6d340d5..8fad7b00700f5a 100644
--- a/llvm/test/CodeGen/DirectX/idot.ll
+++ b/llvm/test/CodeGen/DirectX/idot.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library  %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for dot are generated for int/uint vectors.
 
diff --git a/llvm/test/CodeGen/DirectX/isinf.ll b/llvm/test/CodeGen/DirectX/isinf.ll
index e2975da90bfc1b..bbacaa0f99bac6 100644
--- a/llvm/test/CodeGen/DirectX/isinf.ll
+++ b/llvm/test/CodeGen/DirectX/isinf.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
 
 ; Make sure dxil operation function calls for isinf are generated for float and half.
 ; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
diff --git a/llvm/test/CodeGen/DirectX/isinf_error.ll b/llvm/test/CodeGen/DirectX/isinf_error.ll
index 95b2d0cabcc43b..39b83554d74d0e 100644
--- a/llvm/test/CodeGen/DirectX/isinf_error.ll
+++ b/llvm/test/CodeGen/DirectX/isinf_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation isinf does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/log.ll b/llvm/test/CodeGen/DirectX/log.ll
index 172c3bfed3b770..36344cf7a5f6dd 100644
--- a/llvm/test/CodeGen/DirectX/log.ll
+++ b/llvm/test/CodeGen/DirectX/log.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library  %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for log are generated.
 
diff --git a/llvm/test/CodeGen/DirectX/log10.ll b/llvm/test/CodeGen/DirectX/log10.ll
index d4f827a0d1af83..8e40ccd8d13313 100644
--- a/llvm/test/CodeGen/DirectX/log10.ll
+++ b/llvm/test/CodeGen/DirectX/log10.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-library  %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for log10 are generated.
 
diff --git a/llvm/test/CodeGen/DirectX/log2.ll b/llvm/test/CodeGen/DirectX/log2.ll
index 2164d4db9396d1..d6a7ba0b7dda75 100644
--- a/llvm/test/CodeGen/DirectX/log2.ll
+++ b/llvm/test/CodeGen/DirectX/log2.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for log2 are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/log2_error.ll b/llvm/test/CodeGen/DirectX/log2_error.ll
index a26f6e8c3117f5..b8876854d389fb 100644
--- a/llvm/test/CodeGen/DirectX/log2_error.ll
+++ b/llvm/test/CodeGen/DirectX/log2_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation log2 does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/pow.ll b/llvm/test/CodeGen/DirectX/pow.ll
index 25ce0fe731d0ba..4ed886532f5909 100644
--- a/llvm/test/CodeGen/DirectX/pow.ll
+++ b/llvm/test/CodeGen/DirectX/pow.ll
@@ -1,5 +1,5 @@
 ; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
-; RUN: opt -S  -dxil-op-lower  < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
 
 ; Make sure dxil operation function calls for pow are generated.
 
diff --git a/llvm/test/CodeGen/DirectX/reversebits.ll b/llvm/test/CodeGen/DirectX/reversebits.ll
index b6a7a1bc6152e3..1ade57b40100ff 100644
--- a/llvm/test/CodeGen/DirectX/reversebits.ll
+++ b/llvm/test/CodeGen/DirectX/reversebits.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for reversebits are generated for all integer types.
 
diff --git a/llvm/test/CodeGen/DirectX/round.ll b/llvm/test/CodeGen/DirectX/round.ll
index e0a3772ebca8fa..db953fb29c2046 100644
--- a/llvm/test/CodeGen/DirectX/round.ll
+++ b/llvm/test/CodeGen/DirectX/round.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for round are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/round_error.ll b/llvm/test/CodeGen/DirectX/round_error.ll
index 2d27fbb5ee20de..9d2a4e778a9249 100644
--- a/llvm/test/CodeGen/DirectX/round_error.ll
+++ b/llvm/test/CodeGen/DirectX/round_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; This test is expected to fail with the following error
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/rsqrt.ll b/llvm/test/CodeGen/DirectX/rsqrt.ll
index 52af0e62220b3e..054c84483ef826 100644
--- a/llvm/test/CodeGen/DirectX/rsqrt.ll
+++ b/llvm/test/CodeGen/DirectX/rsqrt.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for rsqrt are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/rsqrt_error.ll b/llvm/test/CodeGen/DirectX/rsqrt_error.ll
index 9cd5002c20f7ec..5e29e37113d19f 100644
--- a/llvm/test/CodeGen/DirectX/rsqrt_error.ll
+++ b/llvm/test/CodeGen/DirectX/rsqrt_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation rsqrt does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index 1f285c433581cf..789f6a73fe6aaa 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -1,8 +1,6 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s -check-prefix=SM6_3
 
 ; Make sure dxil operation function calls for sin are generated for float and half.
-; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
-; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
 
 ; Function Attrs: noinline nounwind optnone
 define noundef float @sin_float(float noundef %a) #0 {
@@ -10,6 +8,7 @@ entry:
   %a.addr = alloca float, align 4
   store float %a, ptr %a.addr, align 4
   %0 = load float, ptr %a.addr, align 4
+  ; SM6_3: call float @dx.op.unary.f32(i32 13, float %{{.*}})
   %1 = call float @llvm.sin.f32(float %0)
   ret float %1
 }
@@ -20,6 +19,7 @@ entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
   %0 = load half, ptr %a.addr, align 2
+  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
index ece0e530315b2f..3e954f32d9bb1b 100644
--- a/llvm/test/CodeGen/DirectX/sin_error.ll
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation sin does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
diff --git a/llvm/test/CodeGen/DirectX/sin_sm_error.ll b/llvm/test/CodeGen/DirectX/sin_sm_error.ll
new file mode 100644
index 00000000000000..84bf54901f313a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_sm_error.ll
@@ -0,0 +1,15 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s
+
+; DXIL operation sin does not support half overload type in SM6.0
+; CHECK: LLVM ERROR: Invalid Overload
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @sin_half(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
+  %1 = call half @llvm.sin.f16(half %0)
+  ret half %1
+}
diff --git a/llvm/test/CodeGen/DirectX/smax.ll b/llvm/test/CodeGen/DirectX/smax.ll
index 8b2406782c0938..bcda51cb0bfba6 100644
--- a/llvm/test/CodeGen/DirectX/smax.ll
+++ b/llvm/test/CodeGen/DirectX/smax.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for smax are generated for i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/smin.ll b/llvm/test/CodeGen/DirectX/smin.ll
index b2b40a1b624335..8d4884704df213 100644
--- a/llvm/test/CodeGen/DirectX/smin.ll
+++ b/llvm/test/CodeGen/DirectX/smin.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for smin are generated for i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/sqrt.ll b/llvm/test/CodeGen/DirectX/sqrt.ll
index 76a572efd20557..792fbc8d0614d3 100644
--- a/llvm/test/CodeGen/DirectX/sqrt.ll
+++ b/llvm/test/CodeGen/DirectX/sqrt.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for sqrt are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/sqrt_error.ll b/llvm/test/CodeGen/DirectX/sqrt_error.ll
index fffa2e19b80fa9..1477abc62c13a2 100644
--- a/llvm/test/CodeGen/DirectX/sqrt_error.ll
+++ b/llvm/test/CodeGen/DirectX/sqrt_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation sqrt does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/trunc.ll b/llvm/test/CodeGen/DirectX/trunc.ll
index 2072f28cef50a0..f00b737da4dbb3 100644
--- a/llvm/test/CodeGen/DirectX/trunc.ll
+++ b/llvm/test/CodeGen/DirectX/trunc.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for trunc are generated for float and half.
 
diff --git a/llvm/test/CodeGen/DirectX/trunc_error.ll b/llvm/test/CodeGen/DirectX/trunc_error.ll
index 751b0b94c280df..ccc7b1df879ee3 100644
--- a/llvm/test/CodeGen/DirectX/trunc_error.ll
+++ b/llvm/test/CodeGen/DirectX/trunc_error.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
 ; DXIL operation trunc does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
diff --git a/llvm/test/CodeGen/DirectX/umax.ll b/llvm/test/CodeGen/DirectX/umax.ll
index be0f557fc8da69..a4bd66ef0bd6c3 100644
--- a/llvm/test/CodeGen/DirectX/umax.ll
+++ b/llvm/test/CodeGen/DirectX/umax.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for umax are generated for i16/i32/i64.
 
diff --git a/llvm/test/CodeGen/DirectX/umin.ll b/llvm/test/CodeGen/DirectX/umin.ll
index 5051c711744892..a551f8ff3bfa9d 100644
--- a/llvm/test/CodeGen/DirectX/umin.ll
+++ b/llvm/test/CodeGen/DirectX/umin.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
 ; Make sure dxil operation function calls for umin are generated for i16/i32/i64.
 
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 77f0d2cb8c937f..0af90ebee5d2d5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -23,7 +23,11 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
+
+#include <algorithm>
+#include <cstdint>
 #include <string>
+#include <vector>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -177,7 +181,16 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   }
 
   // Get valid overload types of the Operation
-  auto OverloadTypeRecs = R->getValueAsListOfDefs("OpOverloadTypes");
+  std::vector<Record *> OverloadTypeRecs =
+      R->getValueAsListOfDefs("OpOverloadTypes");
+  // Sort records in ascending order of Shader Model version
+  std::sort(
+      OverloadTypeRecs.begin(), OverloadTypeRecs.end(),
+      [](Record *a, Record *b) {
+        return (
+            a->getValueAsDef("ShaderModel")->getValueAsInt("MajorAndMinor") <
+            b->getValueAsDef("ShaderModel")->getValueAsInt("MajorAndMinor"));
+      });
   unsigned OverloadTypeRecsSize = OverloadTypeRecs.size();
   // Populate OpOverloads with
   for (unsigned I = 0; I < OverloadTypeRecsSize; I++) {
@@ -268,20 +281,33 @@ static std::string getOverloadKindStr(const Record *R) {
 }
 /// Return a string representation of OverloadKind enum that maps to
 /// input LLVMType record
-/// \param R TableGen def record of class LLVMType
+/// \param Recs Records, each pairing a ShaderModel def with OpOverloads defs
 /// \return std::string string representation of OverloadKind
 
+// Emits a brace-enclosed list of {ShaderModel MajorAndMinor, overload kinds}
+// entries, one per record in Recs, e.g. "{{<ver>, <K1> | <K2>}, ...}".
+
 static std::string
-getOverloadKindStrs(const SmallVector<Record *> OverloadTys) {
-  if (OverloadTys.empty()) {
-    return {};
-  }
+getOverloadKindStrs(const SmallVector<Record *> Recs) {
   std::string OverloadString = "";
-  auto Iter = OverloadTys.begin();
-  OverloadString.append(getOverloadKindStr(*Iter++));
-  for (; Iter != OverloadTys.end(); ++Iter) {
-    OverloadString.append(" | ").append(getOverloadKindStr(*Iter));
+  std::string Prefix = "";
+  OverloadString.append("{");
+  for (auto OvRec : Recs) {
+    OverloadString.append(Prefix).append("{");
+    OverloadString
+        .append(std::to_string(OvRec->getValueAsDef("ShaderModel")
+                                   ->getValueAsInt("MajorAndMinor")))
+        .append(", ");
+    auto OverloadTys = OvRec->getValueAsListOfDefs("OpOverloads");
+    auto Iter = OverloadTys.begin();
+    OverloadString.append(getOverloadKindStr(*Iter++));
+    for (; Iter != OverloadTys.end(); ++Iter) {
+      OverloadString.append(" | ").append(getOverloadKindStr(*Iter));
+    }
+    OverloadString.append("}");
+    Prefix = ", ";
   }
+  OverloadString.append("}");
   return OverloadString;
 }
 
@@ -403,6 +429,7 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
         "{\n";
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
+  std::string Prefix = "";
   for (auto &Op : Ops) {
     // Consider Op.OverloadParamIndex as the overload parameter index, by
     // default
@@ -414,13 +441,14 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     if (OLParamIdx < 0) {
       OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
     }
-    OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
+    OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
        << getOverloadKindStrs(Op.OpOverloads) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
-       << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
+       << Parameters.get(ParameterMap[Op.OpClass]) << " }\n";
+       Prefix = ",";
   }
   OS << "  };\n";
 

>From 01092718e15af5c3aafd2864f003d2587aede9c4 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 5 Apr 2024 11:40:41 -0400
Subject: [PATCH 3/6] Reorganize tests to lower llvm.sin.*. This pattern of
 tests will facilitate use of the same test sources to test lowering with
 various combinations of options.

---
 .../{sin_error.ll => Inputs/sin/double.ll}    |  4 --
 llvm/test/CodeGen/DirectX/Inputs/sin/float.ll |  9 ++++
 .../{sin_sm_error.ll => Inputs/sin/half.ll}   |  6 ---
 llvm/test/CodeGen/DirectX/sin.ll              | 42 +++++++++----------
 4 files changed, 29 insertions(+), 32 deletions(-)
 rename llvm/test/CodeGen/DirectX/{sin_error.ll => Inputs/sin/double.ll} (55%)
 create mode 100644 llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
 rename llvm/test/CodeGen/DirectX/{sin_sm_error.ll => Inputs/sin/half.ll} (50%)

diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
similarity index 55%
rename from llvm/test/CodeGen/DirectX/sin_error.ll
rename to llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
index 3e954f32d9bb1b..949649e9b5b11c 100644
--- a/llvm/test/CodeGen/DirectX/sin_error.ll
+++ b/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
@@ -1,7 +1,3 @@
-; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
-
-; DXIL operation sin does not support double overload type
-; CHECK: LLVM ERROR: Invalid Overload
 
 define noundef double @sin_double(double noundef %a) #0 {
 entry:
diff --git a/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
new file mode 100644
index 00000000000000..6558385e88d67b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
@@ -0,0 +1,9 @@
+; Function Attrs: noinline nounwind optnone
+define noundef float @sin_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %1 = call float @llvm.sin.f32(float %0)
+  ret float %1
+}
diff --git a/llvm/test/CodeGen/DirectX/sin_sm_error.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
similarity index 50%
rename from llvm/test/CodeGen/DirectX/sin_sm_error.ll
rename to llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
index 84bf54901f313a..39fbf3d51701d8 100644
--- a/llvm/test/CodeGen/DirectX/sin_sm_error.ll
+++ b/llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
@@ -1,15 +1,9 @@
-; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s
-
-; DXIL operation sin does not support half overload type in SM6.0
-; CHECK: LLVM ERROR: Invalid Overload
-
 ; Function Attrs: noinline nounwind optnone
 define noundef half @sin_half(half noundef %a) #0 {
 entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
   %0 = load half, ptr %a.addr, align 2
-  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index 789f6a73fe6aaa..ac8ab1ec48339d 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -1,25 +1,23 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s -check-prefix=SM6_3
+// Shader Mode 6.0
+// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/half.ll 2>&1 | FileCheck %s -check-prefix=SM6_0_HALF
+// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/float.ll | FileCheck %s -check-prefix=SM6_0_FLOAT
+// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/inputs/sin/double.ll 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
 
-; Make sure dxil operation function calls for sin are generated for float and half.
+// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/half.ll | FileCheck %s -check-prefix=SM6_3_HALF
+// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/float.ll | FileCheck %s -check-prefix=SM6_3_FLOAT
+// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/inputs/sin/double.ll 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
 
-; Function Attrs: noinline nounwind optnone
-define noundef float @sin_float(float noundef %a) #0 {
-entry:
-  %a.addr = alloca float, align 4
-  store float %a, ptr %a.addr, align 4
-  %0 = load float, ptr %a.addr, align 4
-  ; SM6_3: call float @dx.op.unary.f32(i32 13, float %{{.*}})
-  %1 = call float @llvm.sin.f32(float %0)
-  ret float %1
-}
+// Float is valid for SM6.0
+// SM6_0_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
+
+// Half is not valid for SM6.0
+// SM6_0_HALF: LLVM ERROR: Invalid Overload
+
+// Half and float are valid for SM6.2 and later
+// SM6_3_HALF: call half @dx.op.unary.f16(i32 13, half %{{.*}})
+// SM6_3_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
+
+// Double is not valid in any Shader Model version
+// SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
+// SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
 
-; Function Attrs: noinline nounwind optnone
-define noundef half @sin_half(half noundef %a) #0 {
-entry:
-  %a.addr = alloca half, align 2
-  store half %a, ptr %a.addr, align 2
-  %0 = load half, ptr %a.addr, align 2
-  ; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
-  %1 = call half @llvm.sin.f16(half %0)
-  ret half %1
-}

>From eea6cf8d668484e43c976a1c3d1d719f2189bdd4 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 5 Apr 2024 11:50:47 -0400
Subject: [PATCH 4/6] Delete DXILOpClass defs that are commented out. Need to
 add corresponding class defs when adding DXIL Ops of a new class.

---
 llvm/lib/Target/DirectX/DXIL.td     | 203 +++-------------------------
 llvm/utils/TableGen/DXILEmitter.cpp |   9 +-
 2 files changed, 24 insertions(+), 188 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 0ea0f016fd6c74..ce8a14cda94a9d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -53,194 +53,31 @@ class DXILOpClass<list<LLVMType> OpSig> {
 // utils/hct/hctdb_instrhelp.py of DirectXShaderCompiler repo. The function
 // name has the format "dx.op.<class-name>.<return-type>", in most cases.
 
+// NOTE: The following list is not complete. Classes need to be defined as new DXIL Ops
+// are added.
 defset list<DXILOpClass> OpClasses = {
-def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
-//  def allocateNodeOutputRecords : DXILOpClass;
-def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-//  def annotateHandle : DXILOpClass;
-//  def annotateNodeHandle : DXILOpClass;
-//  def annotateNodeRecordHandle : DXILOpClass;
-//  def atomicBinOp : DXILOpClass;
-//  def atomicCompareExchange : DXILOpClass;
-def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
-def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
-//  def barrierByMemoryHandle : DXILOpClass;
-def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
-//  def barrierByNodeRecordHandle : DXILOpClass;
-def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
-def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
-//  def binaryWithTwoOuts : DXILOpClass;
-//  def bitcastF16toI16 : DXILOpClass;
-//  def bitcastF32toI32 : DXILOpClass;
-//  def bitcastF64toI64 : DXILOpClass;
-//  def bitcastI16toF16 : DXILOpClass;
-//  def bitcastI32toF32 : DXILOpClass;
-//  def bitcastI64toF64 : DXILOpClass;
-//  def bufferLoad : DXILOpClass;
-//  def bufferStore : DXILOpClass;
-//  def bufferUpdateCounter : DXILOpClass;
-//  def calculateLOD : DXILOpClass;
-//  def callShader : DXILOpClass;
-//  def cbufferLoad : DXILOpClass;
-//  def cbufferLoadLegacy : DXILOpClass;
-//  def checkAccessFullyMapped : DXILOpClass;
-//  def coverage : DXILOpClass;
-//  def createHandle : DXILOpClass;
-//  def createHandleForLib : DXILOpClass;
-//  def createHandleFromBinding : DXILOpClass;
-//  def createHandleFromHeap : DXILOpClass;
-//  def createNodeInputRecordHandle : DXILOpClass;
-//  def createNodeOutputHandle : DXILOpClass;
-//  def cutStream : DXILOpClass;
-//  def cycleCounterLegacy : DXILOpClass;
-//  def discard : DXILOpClass;
-//  def dispatchMesh : DXILOpClass;
-//  def dispatchRaysDimensions : DXILOpClass;
-//  def dispatchRaysIndex : DXILOpClass;
-//  def domainLocation : DXILOpClass;
-def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
-//  def dot2AddHalf : DXILOpClass;
-def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
-def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
-//  def dot4AddPacked : DXILOpClass;
-//  def emitIndices : DXILOpClass;
-//  def emitStream : DXILOpClass;
-//  def emitThenCutStream : DXILOpClass;
-//  def evalCentroid : DXILOpClass;
-//  def evalSampleIndex : DXILOpClass;
-//  def evalSnapped : DXILOpClass;
-//  def finishedCrossGroupSharing : DXILOpClass;
-def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
-//  def geometryIndex : DXILOpClass;
-//  def getDimensions : DXILOpClass;
-//  def getInputRecordCount : DXILOpClass;
-//  def getMeshPayload : DXILOpClass;
-//  def getNodeRecordPtr : DXILOpClass;
-//  def getRemainingRecursionLevels : DXILOpClass;
-def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-//  def gsInstanceID : DXILOpClass;
-//  def hitKind : DXILOpClass;
-//  def ignoreHit : DXILOpClass;
-//  def incrementOutputCount : DXILOpClass;
-//  def indexNodeHandle : DXILOpClass;
-//  def innerCoverage : DXILOpClass;
-//  def instanceID : DXILOpClass;
-//  def instanceIndex : DXILOpClass;
-//  def isHelperLane : DXILOpClass;
-def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
-//  def legacyDoubleToFloat : DXILOpClass;
-//  def legacyDoubleToSInt32 : DXILOpClass;
-//  def legacyDoubleToUInt32 : DXILOpClass;
-//  def legacyF16ToF32 : DXILOpClass;
-//  def legacyF32ToF16 : DXILOpClass;
-//  def loadInput : DXILOpClass;
-//  def loadOutputControlPoint : DXILOpClass;
-//  def loadPatchConstant : DXILOpClass;
-//  def makeDouble : DXILOpClass;
-//  def minPrecXRegLoad : DXILOpClass;
-//  def minPrecXRegStore : DXILOpClass;
-//  def nodeOutputIsValid : DXILOpClass;
-//  def objectRayDirection : DXILOpClass;
-//  def objectRayOrigin : DXILOpClass;
-//  def objectToWorld : DXILOpClass;
-//  def outputComplete : DXILOpClass;
-//  def outputControlPointID : DXILOpClass;
-//  def pack4x8 : DXILOpClass;
-//  def primitiveID : DXILOpClass;
-//  def primitiveIndex : DXILOpClass;
-//  def quadOp : DXILOpClass;
-//  def quadReadLaneAt : DXILOpClass;
-//  def quadVote : DXILOpClass;
-//  def quaternary : DXILOpClass;
-//  def rawBufferLoad : DXILOpClass;
-//  def rawBufferStore : DXILOpClass;
-//  def rayFlags : DXILOpClass;
-//  def rayQuery_Abort : DXILOpClass;
-//  def rayQuery_CommitNonOpaqueTriangleHit : DXILOpClass;
-//  def rayQuery_CommitProceduralPrimitiveHit : DXILOpClass;
-//  def rayQuery_Proceed : DXILOpClass;
-//  def rayQuery_StateMatrix : DXILOpClass;
-//  def rayQuery_StateScalar : DXILOpClass;
-//  def rayQuery_StateVector : DXILOpClass;
-//  def rayQuery_TraceRayInline : DXILOpClass;
-//  def rayTCurrent : DXILOpClass;
-//  def rayTMin : DXILOpClass;
-//  def renderTargetGetSampleCount : DXILOpClass;
-//  def renderTargetGetSamplePosition : DXILOpClass;
-//  def reportHit : DXILOpClass;
-//  def sample : DXILOpClass;
-//  def sampleBias : DXILOpClass;
-//  def sampleCmp : DXILOpClass;
-//  def sampleCmpBias : DXILOpClass;
-//  def sampleCmpGrad : DXILOpClass;
-//  def sampleCmpLevel : DXILOpClass;
-//  def sampleCmpLevelZero : DXILOpClass;
-//  def sampleGrad : DXILOpClass;
-//  def sampleIndex : DXILOpClass;
-//  def sampleLevel : DXILOpClass;
-//  def setMeshOutputCounts : DXILOpClass;
-//  def splitDouble : DXILOpClass;
-//  def startInstanceLocation : DXILOpClass;
-//  def startVertexLocation : DXILOpClass;
-//  def storeOutput : DXILOpClass;
-//  def storePatchConstant : DXILOpClass;
-//  def storePrimitiveOutput : DXILOpClass;
-//  def storeVertexOutput : DXILOpClass;
-//  def tempRegLoad : DXILOpClass;
-//  def tempRegStore : DXILOpClass;
-def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
-//  def texture2DMSGetSamplePosition : DXILOpClass;
-//  def textureGather : DXILOpClass;
-//  def textureGatherCmp : DXILOpClass;
-//  def textureGatherRaw : DXILOpClass;
-//  def textureLoad : DXILOpClass;
-//  def textureStore : DXILOpClass;
-//  def textureStoreSample : DXILOpClass;
-def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
-//  def traceRay : DXILOpClass;
-def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
-//  def unaryBits : DXILOpClass;
-//  def unpack4x8 : DXILOpClass;
-//  def viewID : DXILOpClass;
-//  def waveActiveAllEqual : DXILOpClass;
-//  def waveActiveBallot : DXILOpClass;
-//  def waveActiveBit : DXILOpClass;
-//  def waveActiveOp : DXILOpClass;
-//  def waveAllOp : DXILOpClass;
-//  def waveAllTrue : DXILOpClass;
-//  def waveAnyTrue : DXILOpClass;
-//  def waveGetLaneCount : DXILOpClass;
-//  def waveGetLaneIndex : DXILOpClass;
-//  def waveIsFirstLane : DXILOpClass;
-//  def waveMatch : DXILOpClass;
-//  def waveMatrix_Accumulate : DXILOpClass;
-//  def waveMatrix_Annotate : DXILOpClass;
-//  def waveMatrix_Depth : DXILOpClass;
-//  def waveMatrix_Fill : DXILOpClass;
-//  def waveMatrix_LoadGroupShared : DXILOpClass;
-//  def waveMatrix_LoadRawBuf : DXILOpClass;
-//  def waveMatrix_Multiply : DXILOpClass;
-//  def waveMatrix_ScalarOp : DXILOpClass;
-//  def waveMatrix_StoreGroupShared : DXILOpClass;
-//  def waveMatrix_StoreRawBuf : DXILOpClass;
-//  def waveMultiPrefixBitCount : DXILOpClass;
-//  def waveMultiPrefixOp : DXILOpClass;
-//  def wavePrefixOp : DXILOpClass;
-//  def waveReadLaneAt : DXILOpClass;
-//  def waveReadLaneFirst : DXILOpClass;
-//  def worldRayDirection : DXILOpClass;
-//  def worldRayOrigin : DXILOpClass;
-//  def worldToObject : DXILOpClass;
-//  def writeSamplerFeedback : DXILOpClass;
-//  def writeSamplerFeedbackBias : DXILOpClass;
-//  def writeSamplerFeedbackGrad : DXILOpClass;
-//  def writeSamplerFeedbackLevel: DXILOpClass;
+  def acceptHitAndEndSearch : DXILOpClass<[llvm_void_ty]>;
+  def allocateRayQuery : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def attributeAtVertex : DXILOpClass<[llvm_any_ty, llvm_i32_ty, llvm_i32_ty, llvm_i8_ty, llvm_i8_ty]>;
+  def barrier : DXILOpClass<[llvm_void_ty, llvm_i32_ty]>;
+  def barrierByMemoryType : DXILOpClass<[llvm_void_ty, llvm_i32_ty, llvm_i32_ty]>;
+  def binary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>]>;
+  def binaryWithCarryOrBorrow : DXILOpClass<[llvm_i32_ty, llvm_any_ty, LLVMMatchType<0>]>;
+  def dot2 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 5)>;
+  def dot3 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 7)>;
+  def dot4 : DXILOpClass<!listsplat(llvm_anyfloat_ty, 9)>;
+  def flattenedThreadIdInGroup : DXILOpClass<[llvm_i32_ty]>;
+  def groupId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def isSpecialFloat : DXILOpClass<[llvm_i1_ty, llvm_anyfloat_ty]>;
+  def tertiary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
+  def threadId : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def threadIdInGroup : DXILOpClass<[llvm_i32_ty, llvm_i32_ty]>;
+  def unary : DXILOpClass<[llvm_any_ty, LLVMMatchType<0>]>;
 
   // This is a sentinel definition. Hence placed at the end of the list
   // and not as part of the above alphabetically sorted valid definitions.
   // Additionally it is capitalized unlike all the others.
-def UnknownOpClass: DXILOpClass<[]>;
+  def UnknownOpClass: DXILOpClass<[]>;
 }
 
 // Abstraction DXIL Operation to LLVM intrinsic
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 0af90ebee5d2d5..7c9f3da3e157d8 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -287,8 +287,7 @@ static std::string getOverloadKindStr(const Record *R) {
 // Constant value that is used to encode shader model version
 // denoting SM5.0
 
-static std::string
-getOverloadKindStrs(const SmallVector<Record *> Recs) {
+static std::string getOverloadKindStrs(const SmallVector<Record *> Recs) {
   std::string OverloadString = "";
   std::string Prefix = "";
   OverloadString.append("{");
@@ -441,14 +440,14 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
     if (OLParamIdx < 0) {
       OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
     }
-    OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
-       << ", OpCodeClass::" << Op.OpClass << ", "
+    OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", "
+       << OpStrings.get(Op.OpName) << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
        << getOverloadKindStrs(Op.OpOverloads) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " }\n";
-       Prefix = ",";
+    Prefix = ",";
   }
   OS << "  };\n";
 

>From 24b497b75668560d990c57203beb1c16e074e910 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 5 Apr 2024 19:50:35 -0400
Subject: [PATCH 5/6] Define the computation of the value derived from the
 major and minor Shader Model version in a single place, to be used both by
 DXILEmitter and by the lowering pass.

---
 llvm/include/llvm/Support/DXILABI.h        |  5 +++
 llvm/lib/Target/DirectX/DXIL.td            |  5 +--
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |  6 ++--
 llvm/utils/TableGen/DXILEmitter.cpp        | 36 +++++++++++++++-------
 4 files changed, 36 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index da4bea8fc46e3a..a75a85c6f1768c 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -90,4 +90,9 @@ enum class ElementType : uint32_t {
 } // namespace dxil
 } // namespace llvm
 
+// Generate a unique value for given Major, Minor pair of Shader Model
+// version. Allows for 100 minor versions for a given major version number.
+// To be used uniformly by DXILEmitter backend as well as DXIL Lowering pass.
+#define COMPUTE_SM_VERSION_VALUE(MAJ, MIN) ((MAJ * 100) + MIN)
+
 #endif // LLVM_SUPPORT_DXILABI_H
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index ce8a14cda94a9d..4e2042ccf60866 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -16,7 +16,8 @@ include "llvm/IR/Intrinsics.td"
 // Abstract class to demarcate minimum Shader model version required
 // to support DXIL Op
 class DXILShaderModel<int major, int minor> {
-  int MajorAndMinor = !add(!mul(major, 10), minor);
+  int Major = major;
+  int Minor = minor;
 }
 
 // Valid minimum Shader model version records
@@ -25,7 +26,7 @@ class DXILShaderModel<int major, int minor> {
 foreach i = 0...9 in {
   def SM6_#i : DXILShaderModel<6, i>;
 }
-// Shader Mode 7.x - for now 7.0 is defined. Extend as needed
+// Shader Model 7.x - for now 7.0 is defined. Extend as needed
 foreach i = 0 in {
   def SM7_#i : DXILShaderModel<7, i>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index c3217d51ece1b9..7b09f0d24e7a69 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -73,7 +73,7 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   return NewOperands;
 }
 
-static uint32_t getShaderModelVer(Module &M) {
+static uint32_t getModuleShaderModelVersion(Module &M) {
   std::string TTStr = M.getTargetTriple();
   std::string Error;
   auto Target = TargetRegistry::lookupTarget(TTStr, Error);
@@ -85,13 +85,13 @@ static uint32_t getShaderModelVer(Module &M) {
   auto Major = Triple(TTStr).getOSVersion().getMajor();
   auto MinorOrErr = Triple(TTStr).getOSVersion().getMinor();
   uint32_t Minor = MinorOrErr.has_value() ? *MinorOrErr : 0;
-  return ((Major * 10) + Minor);
+  return COMPUTE_SM_VERSION_VALUE(Major, Minor);
 }
 
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   DXILOpBuilder DXILB(M, B);
-  uint32_t SMVer = getShaderModelVer(M);
+  uint32_t SMVer = getModuleShaderModelVersion(M);
   Type *OverloadTy = DXILB.getOverloadTy(DXILOp, SMVer, F.getFunctionType());
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 7c9f3da3e157d8..0210c0b9efbbb2 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -184,13 +184,20 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   std::vector<Record *> OverloadTypeRecs =
       R->getValueAsListOfDefs("OpOverloadTypes");
   // Sort records in ascending order of Shader Model version
-  std::sort(
-      OverloadTypeRecs.begin(), OverloadTypeRecs.end(),
-      [](Record *a, Record *b) {
-        return (
-            a->getValueAsDef("ShaderModel")->getValueAsInt("MajorAndMinor") <
-            b->getValueAsDef("ShaderModel")->getValueAsInt("MajorAndMinor"));
-      });
+  std::sort(OverloadTypeRecs.begin(), OverloadTypeRecs.end(),
+            [](Record *RecA, Record *RecB) {
+              uint16_t RecAMaj =
+                  RecA->getValueAsDef("ShaderModel")->getValueAsInt("Major");
+              uint16_t RecAMin =
+                  RecA->getValueAsDef("ShaderModel")->getValueAsInt("Minor");
+              uint16_t RecBMaj =
+                  RecB->getValueAsDef("ShaderModel")->getValueAsInt("Major");
+              uint16_t RecBMin =
+                  RecB->getValueAsDef("ShaderModel")->getValueAsInt("Minor");
+
+              return (COMPUTE_SM_VERSION_VALUE(RecAMaj, RecAMin) <
+                      COMPUTE_SM_VERSION_VALUE(RecBMaj, RecBMin));
+            });
   unsigned OverloadTypeRecsSize = OverloadTypeRecs.size();
   // Populate OpOverloads with
   for (unsigned I = 0; I < OverloadTypeRecsSize; I++) {
@@ -255,6 +262,11 @@ static std::string getParameterKindStr(ParameterKind Kind) {
   llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
 }
 
+/// Return a string representation of OverloadKind enum that maps to
+/// input LLVMType record
+/// \param R TableGen def record of class LLVMType
+/// \return std::string string representation of OverloadKind
+
 static std::string getOverloadKindStr(const Record *R) {
   Record *VTRec = R->getValueAsDef("VT");
   switch (getValueType(VTRec)) {
@@ -293,10 +305,12 @@ static std::string getOverloadKindStrs(const SmallVector<Record *> Recs) {
   OverloadString.append("{");
   for (auto OvRec : Recs) {
     OverloadString.append(Prefix).append("{");
-    OverloadString
-        .append(std::to_string(OvRec->getValueAsDef("ShaderModel")
-                                   ->getValueAsInt("MajorAndMinor")))
-        .append(", ");
+    uint16_t RecAMaj =
+        OvRec->getValueAsDef("ShaderModel")->getValueAsInt("Major");
+    uint16_t RecAMin =
+        OvRec->getValueAsDef("ShaderModel")->getValueAsInt("Minor");
+    uint16_t RecMajMin = COMPUTE_SM_VERSION_VALUE(RecAMaj, RecAMin);
+    OverloadString.append(std::to_string(RecMajMin)).append(", ");
     auto OverloadTys = OvRec->getValueAsListOfDefs("OpOverloads");
     auto Iter = OverloadTys.begin();
     OverloadString.append(getOverloadKindStr(*Iter++));

>From aee3f5bc179eddc4724a5dbdbd1ef57451143115 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 10 Apr 2024 16:36:51 -0400
Subject: [PATCH 6/6] Incorporate PR review feedback  - Use VersionTuple to
 deal with Shader Model version.  - Undo sin test reorganization.

---
 llvm/include/llvm/Support/DXILABI.h           |  8 +--
 llvm/lib/Target/DirectX/DXIL.td               | 16 ++----
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     | 47 +++++++++--------
 llvm/lib/Target/DirectX/DXILOpBuilder.h       |  7 +--
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    | 10 ++--
 .../test/CodeGen/DirectX/Inputs/sin/double.ll | 10 ----
 llvm/test/CodeGen/DirectX/sin.ll              | 23 ---------
 llvm/test/CodeGen/DirectX/sin_error.ll        | 16 ++++++
 .../{Inputs/sin/float.ll => sin_sm_60.ll}     |  5 ++
 .../sin/half.ll => sin_sm_60_error.ll}        |  5 ++
 llvm/test/CodeGen/DirectX/sin_sm_62.ll        | 25 +++++++++
 llvm/test/CodeGen/DirectX/sin_sm_62_error.ll  | 16 ++++++
 llvm/utils/TableGen/DXILEmitter.cpp           | 51 +++++++------------
 13 files changed, 124 insertions(+), 115 deletions(-)
 delete mode 100644 llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
 delete mode 100644 llvm/test/CodeGen/DirectX/sin.ll
 create mode 100644 llvm/test/CodeGen/DirectX/sin_error.ll
 rename llvm/test/CodeGen/DirectX/{Inputs/sin/float.ll => sin_sm_60.ll} (57%)
 rename llvm/test/CodeGen/DirectX/{Inputs/sin/half.ll => sin_sm_60_error.ll} (57%)
 create mode 100644 llvm/test/CodeGen/DirectX/sin_sm_62.ll
 create mode 100644 llvm/test/CodeGen/DirectX/sin_sm_62_error.ll

diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h
index a75a85c6f1768c..38cdc69a36fb29 100644
--- a/llvm/include/llvm/Support/DXILABI.h
+++ b/llvm/include/llvm/Support/DXILABI.h
@@ -90,9 +90,9 @@ enum class ElementType : uint32_t {
 } // namespace dxil
 } // namespace llvm
 
-// Generate a unique value for given Major, Minor pair of Shader Model
-// version. Allows for 100 minor versions for a given major version number.
-// To be used uniformly by DXILEmitter backend as well as DXIL Lowering pass.
-#define COMPUTE_SM_VERSION_VALUE(MAJ, MIN) ((MAJ * 100) + MIN)
+struct DXILShaderModel {
+  unsigned Major = 0;
+  unsigned Minor = 0;
+};
 
 #endif // LLVM_SUPPORT_DXILABI_H
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 4e2042ccf60866..3f1e842858fb02 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -22,14 +22,10 @@ class DXILShaderModel<int major, int minor> {
 
 // Valid minimum Shader model version records
 
-// Shader Mode 6.x
-foreach i = 0...9 in {
+// Shader Model 6.0 - 6.8
+foreach i = 0...8 in {
   def SM6_#i : DXILShaderModel<6, i>;
 }
-// Shader Model 7.x - for now 7.0 is defined. Extend as needed
-foreach i = 0 in {
-  def SM7_#i : DXILShaderModel<7, i>;
-}
 
 // Abstraction of class mapping valid DXIL Op overloads the minimum
 // version of Shader Model they are supported
@@ -109,14 +105,13 @@ let OpClass = isSpecialFloat in {
                            "Determines if the specified value is infinite.">;
 }
 
-// Unary Class
 let OpClass = unary in {
   def Abs : DXILOpMapping<6, int_fabs, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
                           "Returns the absolute value of the input.">;
 
   def Cos  : DXILOpMapping<12, int_cos, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
                           "Returns cosine(theta) for theta in radians.">;
-  def Sin  : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_3, [llvm_half_ty, llvm_float_ty]>,
+  def Sin  : DXILOpMapping<13, int_sin, [DXILOpOverload<SM6_2, [llvm_half_ty, llvm_float_ty]>,
                                          DXILOpOverload<SM6_0, [llvm_float_ty]>],
                            "Returns sine(theta) for theta in radians.">;
   def Exp2 : DXILOpMapping<21, int_exp2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
@@ -146,7 +141,6 @@ let OpClass = unary in {
                             "Returns the specified value with its bits reversed.">;
 }
 
-// Binary Class
 let OpClass = binary in {
 // Float overloads
   def FMax : DXILOpMapping<35, int_maxnum, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
@@ -164,10 +158,7 @@ let OpClass = binary in {
                            "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
 }
 
-// Tertiary Class
 let OpClass = tertiary in {
-// Float overloads
-//   let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty, llvm_double_ty] in {
   def FMad : DXILOpMapping<46, int_fmuladd, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty, llvm_double_ty]>],
                             "Floating point arithmetic multiply/add operation."
                             " fmad(m,a,b) = m * a + b.">;
@@ -181,7 +172,6 @@ def UMad : DXILOpMapping<49, int_dx_umad, [DXILOpOverload<SM6_0, [llvm_i16_ty, l
 }
 
 // Dot Operations
-// let OpOverloadTypes =  [llvm_half_ty, llvm_float_ty] in
 let OpClass = dot2 in
   def Dot2 : DXILOpMapping<54, int_dx_dot2, [DXILOpOverload<SM6_0, [llvm_half_ty, llvm_float_ty]>],
                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] +"
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 1a4f8c709c0fd2..2d03df09ff7d2e 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -15,6 +15,9 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/VersionTuple.h"
+#include <algorithm>
+#include <cassert>
 #include <string>
 
 using namespace llvm;
@@ -125,7 +128,7 @@ static std::string getTypeName(OverloadKind Kind, Type *Ty) {
 }
 
 struct OpSMOverloadProp {
-  uint16_t ShaderModelVer;
+  DXILShaderModel ShaderModelVer;
   uint16_t ValidTys;
 };
 
@@ -256,35 +259,35 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
 }
 
 static uint16_t getValidOverloadMask(const OpCodeProperty *Prop,
-                                     uint32_t SMVer) {
+                                     VersionTuple SMVer) {
   uint16_t ValidTyMask = 0;
   // std::vector Prop->OverloadProp is in ascending order of SM Version
   // Overloads of highest SM version that is not greater than SMVer
   // are the ones that are valid for SMVer.
-  for (auto OL : Prop->OverloadProp) {
-    if (OL.ShaderModelVer <= SMVer) {
-      ValidTyMask = OL.ValidTys;
-    } else {
-      break;
-    }
-  }
+
+  // Get the lower bound value iterator of SMVer
+  auto LaterSM = std::lower_bound(
+      Prop->OverloadProp.begin(), Prop->OverloadProp.end(), SMVer,
+      [](const OpSMOverloadProp OL, VersionTuple VerTup) {
+        return (VersionTuple(OL.ShaderModelVer.Major,
+                             OL.ShaderModelVer.Minor) <= VerTup);
+      });
+  // Valid overloads are of the version prior to the lower bound
+  ValidTyMask = (--LaterSM)->ValidTys;
+  assert(ValidTyMask != 0 && "No valid overload types found");
   return ValidTyMask;
 }
 
 namespace llvm {
 namespace dxil {
 
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
-                                          Type *ReturnTy, Type *OverloadTy,
+CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode,
+                                          VersionTuple &SMVer, Type *ReturnTy,
+                                          Type *OverloadTy,
                                           SmallVector<Value *> Args) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
   uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
 
-  if (ValidTyMask == 0) {
-    report_fatal_error(StringRef(std::to_string(SMVer).append(
-                           ": Unhandled Shader Model Version")),
-                       /*gen_crash_diag*/ false);
-  }
   OverloadKind Kind = getOverloadKind(OverloadTy);
   if ((ValidTyMask & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
@@ -304,7 +307,7 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
   return B.CreateCall(DXILFn, Args);
 }
 
-Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
+Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, VersionTuple &SMVer,
                                    FunctionType *FT) {
 
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
@@ -313,11 +316,6 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
   if (Prop->OverloadParamIndex < 0) {
     auto &Ctx = FT->getContext();
     uint16_t ValidTyMask = getValidOverloadMask(Prop, SMVer);
-    if (ValidTyMask == 0) {
-      report_fatal_error(StringRef(std::to_string(SMVer).append(
-                             ": Unhandled Shader Model Version")),
-                         /*gen_crash_diag*/ false);
-    }
 
     switch (ValidTyMask) {
     case OverloadKind::VOID:
@@ -344,14 +342,15 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer,
     }
   }
 
-  // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
+  // Default to FT->getReturnType() as the overload type; a non-zero
+  // Prop->OverloadParamIndex selects the corresponding parameter type instead.
   Type *OverloadType = FT->getReturnType();
   if (Prop->OverloadParamIndex != 0) {
     // Skip Return Type.
     OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);
   }
 
-  auto ParamKinds = getOpCodeParameterKind(*Prop);
+  const auto *ParamKinds = getOpCodeParameterKind(*Prop);
   auto Kind = ParamKinds[Prop->OverloadParamIndex];
   // For ResRet and CBufferRet, OverloadTy is in field of StructType.
   if (Kind == ParameterKind::CBufferRet ||
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 1e15286c810a8a..43da15cec03923 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -14,7 +14,7 @@
 
 #include "DXILConstants.h"
 #include "llvm/ADT/SmallVector.h"
-#include <cstdint>
+#include "llvm/Support/VersionTuple.h"
 
 namespace llvm {
 class Module;
@@ -38,10 +38,11 @@ class DXILOpBuilder {
   /// \param ReturnTy Return type of the DXIL Op call constructed
   /// \param OverloadTy Overload type of the DXIL Op call constructed
   /// \return DXIL Op call constructed
-  CallInst *createDXILOpCall(dxil::OpCode OpCode, uint32_t SMVer,
+  CallInst *createDXILOpCall(dxil::OpCode OpCode, VersionTuple &SMVer,
                              Type *ReturnTy, Type *OverloadTy,
                              SmallVector<Value *> Args);
-  Type *getOverloadTy(dxil::OpCode OpCode, uint32_t SMVer, FunctionType *FT);
+  Type *getOverloadTy(dxil::OpCode OpCode, VersionTuple &SMVer,
+                      FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
 private:
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 7b09f0d24e7a69..8b989cf9f072ed 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -25,6 +25,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/VersionTuple.h"
 
 #define DEBUG_TYPE "dxil-op-lower"
 
@@ -73,7 +74,7 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
   return NewOperands;
 }
 
-static uint32_t getModuleShaderModelVersion(Module &M) {
+static VersionTuple getModuleShaderModelVersion(Module &M) {
   std::string TTStr = M.getTargetTriple();
   std::string Error;
   auto Target = TargetRegistry::lookupTarget(TTStr, Error);
@@ -82,16 +83,13 @@ static uint32_t getModuleShaderModelVersion(Module &M) {
       report_fatal_error(StringRef(Error), /*gen_crash_diag*/ false);
     }
   }
-  auto Major = Triple(TTStr).getOSVersion().getMajor();
-  auto MinorOrErr = Triple(TTStr).getOSVersion().getMinor();
-  uint32_t Minor = MinorOrErr.has_value() ? *MinorOrErr : 0;
-  return COMPUTE_SM_VERSION_VALUE(Major, Minor);
+  return Triple(TTStr).getOSVersion();
 }
 
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   DXILOpBuilder DXILB(M, B);
-  uint32_t SMVer = getModuleShaderModelVersion(M);
+  VersionTuple SMVer = getModuleShaderModelVersion(M);
   Type *OverloadTy = DXILB.getOverloadTy(DXILOp, SMVer, F.getFunctionType());
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
diff --git a/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll b/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
deleted file mode 100644
index 949649e9b5b11c..00000000000000
--- a/llvm/test/CodeGen/DirectX/Inputs/sin/double.ll
+++ /dev/null
@@ -1,10 +0,0 @@
-
-define noundef double @sin_double(double noundef %a) #0 {
-entry:
-  %a.addr = alloca double, align 8
-  store double %a, ptr %a.addr, align 8
-  %0 = load double, ptr %a.addr, align 8
-  %1 = call double @llvm.sin.f64(double %0)
-  ret double %1
-}
-
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
deleted file mode 100644
index ac8ab1ec48339d..00000000000000
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ /dev/null
@@ -1,23 +0,0 @@
-// Shader Mode 6.0
-// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/half.ll 2>&1 | FileCheck %s -check-prefix=SM6_0_HALF
-// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/Inputs/sin/float.ll | FileCheck %s -check-prefix=SM6_0_FLOAT
-// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %S/inputs/sin/double.ll 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
-
-// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/half.ll | FileCheck %s -check-prefix=SM6_3_HALF
-// RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/Inputs/sin/float.ll | FileCheck %s -check-prefix=SM6_3_FLOAT
-// RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %S/inputs/sin/double.ll 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
-
-// Float is valid for SM6.0
-// SM6_0_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
-
-// Half is not valid for SM6.0
-// SM6_0_HALF: LLVM ERROR: Invalid Overload
-
-// Half and float are valid for SM6.2 and later
-// SM6_3_HALF: call half @dx.op.unary.f16(i32 13, half %{{.*}})
-// SM6_3_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
-
-// Double is not valid in any Shader Model version
-// SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
-// SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
-
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
new file mode 100644
index 00000000000000..2e4c25058ca961
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -0,0 +1,16 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
+
+; Double is not valid in any Shader Model version
+; SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
+; SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
+
+define noundef double @sin_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %1 = call double @llvm.sin.f64(double %0)
+  ret double %1
+}
+
diff --git a/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll b/llvm/test/CodeGen/DirectX/sin_sm_60.ll
similarity index 57%
rename from llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
rename to llvm/test/CodeGen/DirectX/sin_sm_60.ll
index 6558385e88d67b..09dbfce453af69 100644
--- a/llvm/test/CodeGen/DirectX/Inputs/sin/float.ll
+++ b/llvm/test/CodeGen/DirectX/sin_sm_60.ll
@@ -1,3 +1,8 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s | FileCheck %s -check-prefix=SM6_0_FLOAT
+
+; Float is valid for SM6.0
+; SM6_0_FLOAT: call float @dx.op.unary.f32(i32 13, float %{{.*}})
+
 ; Function Attrs: noinline nounwind optnone
 define noundef float @sin_float(float noundef %a) #0 {
 entry:
diff --git a/llvm/test/CodeGen/DirectX/Inputs/sin/half.ll b/llvm/test/CodeGen/DirectX/sin_sm_60_error.ll
similarity index 57%
rename from llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
rename to llvm/test/CodeGen/DirectX/sin_sm_60_error.ll
index 39fbf3d51701d8..936b8939b31559 100644
--- a/llvm/test/CodeGen/DirectX/Inputs/sin/half.ll
+++ b/llvm/test/CodeGen/DirectX/sin_sm_60_error.ll
@@ -1,3 +1,8 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s -check-prefix=SM6_0_HALF
+
+; Half is not valid for SM6.0
+; SM6_0_HALF: LLVM ERROR: Invalid Overload
+
 ; Function Attrs: noinline nounwind optnone
 define noundef half @sin_half(half noundef %a) #0 {
 entry:
diff --git a/llvm/test/CodeGen/DirectX/sin_sm_62.ll b/llvm/test/CodeGen/DirectX/sin_sm_62.ll
new file mode 100644
index 00000000000000..c484744404459d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_sm_62.ll
@@ -0,0 +1,25 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -check-prefix=SM6_3
+; Half and float are valid for SM6.2 and later
+; SM6_3: call half @dx.op.unary.f16(i32 13, half %{{.*}})
+; SM6_3: call float @dx.op.unary.f32(i32 13, float %{{.*}})
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @sin_half(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  %1 = call half @llvm.sin.f16(half %0)
+  ret half %1
+}
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @sin_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %1 = call float @llvm.sin.f32(float %0)
+  ret float %1
+}
+
diff --git a/llvm/test/CodeGen/DirectX/sin_sm_62_error.ll b/llvm/test/CodeGen/DirectX/sin_sm_62_error.ll
new file mode 100644
index 00000000000000..2e4c25058ca961
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_sm_62_error.ll
@@ -0,0 +1,16 @@
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library %s 2>&1 | FileCheck %s --check-prefix=SM6_0_DOUBLE
+; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s --check-prefix=SM6_3_DOUBLE
+
+; Double is not valid in any Shader Model version
+; SM6_0_DOUBLE: LLVM ERROR: Invalid Overload
+; SM6_3_DOUBLE: LLVM ERROR: Invalid Overload
+
+define noundef double @sin_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %1 = call double @llvm.sin.f64(double %0)
+  ret double %1
+}
+
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 0210c0b9efbbb2..1f68edf622d702 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -17,15 +17,14 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
-#include "llvm/ADT/StringSwitch.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
 #include "llvm/Support/DXILABI.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/VersionTuple.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
 #include <algorithm>
-#include <cstdint>
 #include <string>
 #include <vector>
 
@@ -34,11 +33,6 @@ using namespace llvm::dxil;
 
 namespace {
 
-struct DXILShaderModel {
-  int Major = 0;
-  int Minor = 0;
-};
-
 struct DXILOperationDesc {
   std::string OpName; // name of DXIL operation
   int OpCode;         // ID of DXIL operation
@@ -101,9 +95,8 @@ static ParameterKind getParameterKind(const Record *R) {
   case MVT::Any:
     return ParameterKind::Overload;
   default:
-    report_fatal_error(
-        "Support for specified parameter type not yet implemented",
-        /*gen_crash_diag*/ false);
+    llvm_unreachable(
+        "Support for specified parameter type not yet implemented");
   }
 }
 
@@ -147,11 +140,8 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
             break;
           }
         }
-        if (!KnownType) {
-          report_fatal_error("Specification of multiple differing overload "
-                             "parameter types not yet supported",
-                             /*gen_crash_diag*/ false);
-        }
+        assert(KnownType && "Specification of multiple differing overload "
+                            "parameter types not yet supported");
       } else {
         OverloadParamIndices.push_back(I);
       }
@@ -174,9 +164,8 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // Set the index of the overload parameter, if any.
   OverloadParamIndex = -1; // default; indicating none
   if (!OverloadParamIndices.empty()) {
-    if (OverloadParamIndices.size() > 1)
-      report_fatal_error("Multiple overload type specification not supported",
-                         /*gen_crash_diag*/ false);
+    assert(OverloadParamIndices.size() == 1 &&
+           "Multiple overload type specification not supported");
     OverloadParamIndex = OverloadParamIndices[0];
   }
 
@@ -186,17 +175,17 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // Sort records in ascending order of Shader Model version
   std::sort(OverloadTypeRecs.begin(), OverloadTypeRecs.end(),
             [](Record *RecA, Record *RecB) {
-              uint16_t RecAMaj =
+              unsigned RecAMaj =
                   RecA->getValueAsDef("ShaderModel")->getValueAsInt("Major");
-              uint16_t RecAMin =
+              unsigned RecAMin =
                   RecA->getValueAsDef("ShaderModel")->getValueAsInt("Minor");
-              uint16_t RecBMaj =
+              unsigned RecBMaj =
                   RecB->getValueAsDef("ShaderModel")->getValueAsInt("Major");
-              uint16_t RecBMin =
+              unsigned RecBMin =
                   RecB->getValueAsDef("ShaderModel")->getValueAsInt("Minor");
 
-              return (COMPUTE_SM_VERSION_VALUE(RecAMaj, RecAMin) <
-                      COMPUTE_SM_VERSION_VALUE(RecBMaj, RecBMin));
+              return (VersionTuple(RecAMaj, RecAMin) <
+                      VersionTuple(RecBMaj, RecBMin));
             });
   unsigned OverloadTypeRecsSize = OverloadTypeRecs.size();
   // Populate OpOverloads with
@@ -295,22 +284,20 @@ static std::string getOverloadKindStr(const Record *R) {
 /// input LLVMType record
 /// \param Recs A vector of records of TableGen class type DXILShaderModel
 /// \return std::string string representation of OverloadKind
-
-// Constant value that is used to encode shader model version
-// denoting SM5.0
-
 static std::string getOverloadKindStrs(const SmallVector<Record *> Recs) {
   std::string OverloadString = "";
   std::string Prefix = "";
   OverloadString.append("{");
   for (auto OvRec : Recs) {
     OverloadString.append(Prefix).append("{");
-    uint16_t RecAMaj =
+    unsigned RecAMaj =
         OvRec->getValueAsDef("ShaderModel")->getValueAsInt("Major");
-    uint16_t RecAMin =
+    unsigned RecAMin =
         OvRec->getValueAsDef("ShaderModel")->getValueAsInt("Minor");
-    uint16_t RecMajMin = COMPUTE_SM_VERSION_VALUE(RecAMaj, RecAMin);
-    OverloadString.append(std::to_string(RecMajMin)).append(", ");
+    OverloadString.append("{")
+        .append(std::to_string(RecAMaj))
+        .append(", ")
+        .append(std::to_string(RecAMin).append("}, "));
     auto OverloadTys = OvRec->getValueAsListOfDefs("OpOverloads");
     auto Iter = OverloadTys.begin();
     OverloadString.append(getOverloadKindStr(*Iter++));



More information about the llvm-commits mailing list