[llvm] [DXIL] Add constraint specification and backend implementation of DXIL Ops (PR #97593)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 8 09:46:29 PDT 2024


================
@@ -212,154 +242,576 @@ defset list<DXILOpClass> OpClasses = {
   def UnknownOpClass: DXILOpClass;
 }
 
-// Several of the overloaded DXIL Operations support for data types
-// that are a subset of the overloaded LLVM intrinsics that they map to.
-// For e.g., llvm.sin.* intrinsic operates on any floating-point type and
-// maps for lowering to DXIL Op Sin. However, valid overloads of DXIL Sin
-// operation overloads are half (f16) and float (f32) only.
-//
-// The following abstracts overload types specific to DXIL operations.
-
-class DXILType : LLVMType<OtherVT> {
-  let isAny = 1;
-  int isI16OrI32 = 0;
-  int isHalfOrFloat = 0;
-}
-
-// Concrete records for various overload types supported specifically by
-// DXIL Operations.
-let isI16OrI32 = 1 in
-  def llvm_i16ori32_ty : DXILType;
-
-let isHalfOrFloat = 1 in
-  def llvm_halforfloat_ty : DXILType;
-
-// Abstraction DXIL Operation to LLVM intrinsic
-class DXILOpMappingBase {
-  int OpCode = 0;                      // Opcode of DXIL Operation
-  DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
-  Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
-  string Doc = "";                     // A short description of the operation
-  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
-                                       // format [returnTy, param1ty, ...]
-}
-
-class DXILOpMapping<int opCode, DXILOpClass opClass,
-                    Intrinsic intrinsic, string doc,
-                    list<LLVMType> opTys = []> : DXILOpMappingBase {
-  int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
-  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
-  string Doc = doc;                    // to a short description of the operation
-  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
-}
-
-// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
-def Abs : DXILOpMapping<6, unary, int_fabs,
-                         "Returns the absolute value of the input.">;
-def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
-                         "Determines if the specified value is infinite.",
-                         [llvm_i1_ty, llvm_halforfloat_ty]>;
-def Cos  : DXILOpMapping<12, unary, int_cos,
-                         "Returns cosine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Tan  : DXILOpMapping<14, unary, int_tan,
-                         "Returns tangent(theta) for theta in radians.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def ACos  : DXILOpMapping<15, unary, int_acos,
-                         "Returns the arccosine of each component of input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def ASin  : DXILOpMapping<16, unary, int_asin,
-                         "Returns the arcsine of each component of input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def ATan  : DXILOpMapping<17, unary, int_atan,
-                         "Returns the arctangent of each component of input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def HCos  : DXILOpMapping<18, unary, int_cosh,
-                         "Returns the hyperbolic cosine of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def HSin  : DXILOpMapping<19, unary, int_sinh,
-                         "Returns the hyperbolic sine of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def HTan  : DXILOpMapping<20, unary, int_tanh,
-                         "Returns the hyperbolic tan of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-
-def Exp2 : DXILOpMapping<21, unary, int_exp2,
-                         "Returns the base 2 exponential, or 2**x, of the specified value."
-                         "exp2(x) = 2**x.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Frac : DXILOpMapping<22, unary, int_dx_frac,
-                         "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Log2 : DXILOpMapping<23, unary, int_log2,
-                         "Returns the base-2 logarithm of the specified value.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Sqrt : DXILOpMapping<24, unary, int_sqrt,
-                         "Returns the square root of the specified floating-point"
-                         "value, per component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
-                         "Returns the reciprocal of the square root of the specified value."
-                         "rsqrt(x) = 1 / sqrt(x).",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Round : DXILOpMapping<26, unary, int_roundeven,
-                         "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Floor : DXILOpMapping<27, unary, int_floor,
-                         "Returns the largest integer that is less than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Ceil : DXILOpMapping<28, unary, int_ceil,
-                         "Returns the smallest integer that is greater than or equal to the input.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Trunc : DXILOpMapping<29, unary, int_trunc,
-                         "Returns the specified value truncated to the integer component.",
-                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
-def Rbits : DXILOpMapping<30, unary, int_bitreverse,
-                         "Returns the specified value with its bits reversed.",
-                         [llvm_anyint_ty, LLVMMatchType<0>]>;
-def FMax : DXILOpMapping<35, binary, int_maxnum,
-                         "Float maximum. FMax(a,b) = a > b ? a : b">;
-def FMin : DXILOpMapping<36, binary, int_minnum,
-                         "Float minimum. FMin(a,b) = a < b ? a : b">;
-def SMax : DXILOpMapping<37, binary, int_smax,
-                         "Signed integer maximum. SMax(a,b) = a > b ? a : b">;
-def SMin : DXILOpMapping<38, binary, int_smin,
-                         "Signed integer minimum. SMin(a,b) = a < b ? a : b">;
-def UMax : DXILOpMapping<39, binary, int_umax,
-                         "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
-def UMin : DXILOpMapping<40, binary, int_umin,
-                         "Unsigned integer minimum. UMin(a,b) = a < b ? a : b">;
-def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
-                         "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m * a + b.">;
-def IMad : DXILOpMapping<48, tertiary, int_dx_imad,
-                         "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">;
-def UMad : DXILOpMapping<49, tertiary, int_dx_umad,
-                         "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in
-  def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in
-  def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">;
-let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in
-  def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4,
-                           "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">;
-def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
-                             "Reads the thread ID">;
-def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
-                             "Reads the group ID (SV_GroupID)">;
-def ThreadIdInGroup : DXILOpMapping<95, threadIdInGroup,
-                                    int_dx_thread_id_in_group,
-                                    "Reads the thread ID within the group "
-                                    "(SV_GroupThreadID)">;
-def FlattenedThreadIdInGroup : DXILOpMapping<96, flattenedThreadIdInGroup,
-                                             int_dx_flattened_thread_id_in_group,
-                                             "Provides a flattened index for a "
-                                             "given thread within a given "
-                                             "group (SV_GroupIndex)">;
+// Shader stage kinds. Every `def` below is collected, in declaration order,
+// into the ShaderStages list by the enclosing defset. These records are used
+// as the arguments of a `stages` dag in SMVersionConstraints, e.g.
+// `(stages allKinds)`.
+class ShaderStage;
+
+// NOTE(review): `allKinds` presumably means "valid in every shader stage"
+// rather than being a real stage; confirm against the DXIL emitter.
+defset list<ShaderStage> ShaderStages = {
+  def compute : ShaderStage;
+  def domain : ShaderStage;
+  def hull : ShaderStage;
+  def pixel : ShaderStage;
+  def vertex : ShaderStage;
+  def geometry : ShaderStage;
+  def library : ShaderStage;
+  def amplification : ShaderStage;
+  def mesh : ShaderStage;
+  def node : ShaderStage;
+  def raygeneration : ShaderStage;
+  def intersection : ShaderStage;
+  def allKinds : ShaderStage;
+}
+
+// Primitive predicate: the base class of predicates that can constrain a
+// DXIL Op (for example SMVersion).
+class Pred;
+
+// Shader Model version predicate. This translates to a check for the
+// specified Shader Model version.
+//   ver - the Shader Model version (e.g. SM6_0), stored in sm_version.
+class SMVersion<Version ver> : Pred {
+  Version sm_version = ver;
+}
+
+// Class abstraction of constraints predicated on Shader Model version.
+//   ver     - Shader Model version the constraints are predicated on.
+//   oloads  - valid overload types, as an `overloads` dag,
+//             e.g. (overloads llvm_half_ty, llvm_float_ty).
+//   stages  - shader stages, as a `stages` dag, e.g. (stages allKinds).
+class SMVersionConstraints<Version ver, dag oloads, dag stages> : SMVersion<ver> {
+  dag overload_types = oloads;
+  dag stage_kinds = stages;
+}
+
+// The records below carry no data. Each is used only as the operator of a
+// dag value, which tags what kind of list that dag holds.
+
+// Marker used to identify argument list, e.g. `(ins llvm_anyfloat_ty)`.
+def ins;
+
+// Marker used to identify result list, e.g. `(out llvm_i1_ty)`.
+def out;
+
+// Marker used to identify list of shader model based attributes.
+// NOTE(review): not used by any def in the visible part of this file.
+def sm_attrs;
+
+// Marker used to identify overload types list,
+// e.g. `(overloads llvm_half_ty, llvm_float_ty)`.
+def overloads;
+
+// Marker used to identify stage kinds list, e.g. `(stages allKinds)`.
+def stages;
+
+// Marker used to identify attribute list, e.g. `(attrs ReadNone)`.
+def attrs;
+
+// Abstraction of a DXIL Operation. Each concrete DXIL Op is a `def` of this
+// class that overrides the fields below with `let`.
+class DXILOp {
+  // A short description of the operation.
+  string Doc = "";
+
+  // Opcode of the DXIL Operation.
+  int OpCode = 0;
+
+  // Class of the DXIL Operation.
+  DXILOpClass OpClass = UnknownOpClass;
+
+  // LLVM Intrinsic the DXIL Operation maps to.
+  Intrinsic LLVMIntrinsic = ?;
+
+  // TODO : DELETE THIS once support in DXILEmitter is added to consume
+  // overload_types and generate appropriate code.
+  // Valid overload types of the DXIL Operation. In the defs visible here this
+  // repeats the types listed in the `overloads` dag of sm_constraints.
+  list<LLVMType> OpOverloadTypes = ?;
+
+  // Arguments of the op, as an `ins` dag. Defaults to no arguments.
+  dag arguments = (ins);
+
+  // Results of the op, as an `out` dag. Defaults to no results.
+  dag result = (out);
+
+  // List of constraints predicated on Shader Model version.
+  // This field has no default and must be specified. If a DXIL Op has no
+  // overloads or stages predicated on Shader Model version, the minimum
+  // Shader Model version in which the DXIL Op is supported should be
+  // specified as a single list item:
+  //       [SMVersionConstraints<SMX_Y, (overloads), (stages allKinds)>]
+  // If the DXIL Op is not predicated on Shader Model version at all, it
+  // should be specified as an empty list.
+
+  list<SMVersionConstraints> sm_constraints;
+
+  // Operation attributes that are not predicated on Shader Model version,
+  // as an `attrs` dag, e.g. (attrs ReadNone). Defaults to no attributes.
+  dag attributes = (attrs);
+
+  // NOTE(review): presumably the minimum DXIL version of the op. It is left
+  // unset (`?`) here and no visible def assigns it; confirm intended use.
+  Version DXILVersion = ?;
+}
+
+// Concrete definitions of DXIL Operations
+
+// IsInf: lowers int_dx_isinf. Takes a half or float value and yields an i1.
+// It carries the same ReadNone attribute as Sin, so every pure op is
+// annotated consistently, as agreed in review.
+def IsInf : DXILOp {
+  let Doc = "Determines if the specified value is infinite.";
+  let OpCode = 9;
+  let OpClass = isSpecialFloat;
+  let LLVMIntrinsic = int_dx_isinf;
+  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty];
+  let arguments = (ins llvm_anyfloat_ty);
+  let result = (out llvm_i1_ty);
+  let sm_constraints = [SMVersionConstraints<SM6_0,
+                             (overloads llvm_half_ty, llvm_float_ty),
+                             (stages allKinds)>];
+  let attributes = (attrs ReadNone);
+}
+
+// Abs: lowers int_fabs for half, float and double. The argument matches the
+// overload type. It carries the same ReadNone attribute as Sin, so every pure
+// op is annotated consistently, as agreed in review.
+def Abs : DXILOp {
+  let Doc = "Returns the absolute value of the input.";
+  let OpCode = 6;
+  let OpClass = unary;
+  let LLVMIntrinsic = int_fabs;
+  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty, llvm_double_ty];
+  let arguments = (ins LLVMMatchType<0>);
+  let result = (out dxil_overload_ty);
+  let sm_constraints = [SMVersionConstraints<SM6_0,
+                             (overloads llvm_half_ty, llvm_float_ty, llvm_double_ty),
+                             (stages allKinds)>];
+  let attributes = (attrs ReadNone);
+}
+
+// Cos: lowers int_cos for half and float. The argument matches the overload
+// type. Like Sin, it is annotated ReadNone, since the two ops have identical
+// semantics apart from the function computed.
+def Cos : DXILOp {
+  let Doc = "Returns cosine(theta) for theta in radians.";
+  let OpCode = 12;
+  let OpClass = unary;
+  let LLVMIntrinsic = int_cos;
+  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty];
+  let arguments = (ins LLVMMatchType<0>);
+  let result = (out dxil_overload_ty);
+  let sm_constraints = [SMVersionConstraints<SM6_0,
+                             (overloads llvm_half_ty, llvm_float_ty),
+                             (stages allKinds)>];
+  let attributes = (attrs ReadNone);
+}
+
+def Sin  : DXILOp {
+  let Doc ="Returns sine(theta) for theta in radians.";
+  let OpCode = 13;
+  let OpClass = unary;
+  let LLVMIntrinsic = int_sin;
+  let OpOverloadTypes = [llvm_half_ty, llvm_float_ty];
+  let arguments = (ins LLVMMatchType<0>);
+  let result = (out dxil_overload_ty);
+  let sm_constraints = [SMVersionConstraints<SM6_0,
+                          (overloads llvm_half_ty, llvm_float_ty),
+                          (stages allKinds)>];
+  let attributes = (attrs ReadNone);
----------------
bharadwajy wrote:

> Why is this the only op with any attributes? I can't imagine `sin` being different from `cos` here...

Added attributes to other DXIL Ops.

https://github.com/llvm/llvm-project/pull/97593


More information about the llvm-commits mailing list