[llvm] [DirectX][NFC] Leverage LLVM and DirectX intrinsic description in DXIL Op records (PR #83193)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 28 11:30:25 PST 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/83193

>From 3ea3ad12abed583ce191acade7229e487ccd12b5 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 23 Feb 2024 19:12:12 -0500
Subject: [PATCH 1/3] [DirectX][NFC] Simplified DXIL Operation mapping to LLVM
 or DirectX intrinsics in DXIL.td.

Updated DXILEmitter backend to consume the change in the TableGen record specification.
Updated DXILOpBuilder accordingly.
Ensured that corresponding lit tests pass.
---
 llvm/lib/Target/DirectX/DXIL.td            | 146 ++---------
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp  |  25 +-
 llvm/lib/Target/DirectX/DXILOpBuilder.h    |   3 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |   3 +-
 llvm/test/CodeGen/DirectX/comput_ids.ll    |   8 +-
 llvm/utils/TableGen/DXILEmitter.cpp        | 281 ++++++++-------------
 6 files changed, 150 insertions(+), 316 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 8a3454c89542ce..447887fbd474f8 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -12,139 +12,23 @@
 //===----------------------------------------------------------------------===//
 
 include "llvm/IR/Intrinsics.td"
-include "llvm/IR/Attributes.td"
-
-// Abstract representation of the class a DXIL Operation belongs to.
-class DXILOpClass<string name> {
-  string Name = name;
-}
-
-// Abstract representation of the category a DXIL Operation belongs to
-class DXILOpCategory<string name> {
-  string Name = name;
-}
-
-def UnaryClass : DXILOpClass<"Unary">;
-def BinaryClass : DXILOpClass<"Binary">;
-def FlattenedThreadIdInGroupClass : DXILOpClass<"FlattenedThreadIdInGroup">;
-def ThreadIdInGroupClass : DXILOpClass<"ThreadIdInGroup">;
-def ThreadIdClass : DXILOpClass<"ThreadId">;
-def GroupIdClass : DXILOpClass<"GroupId">;
-
-def BinaryUintCategory : DXILOpCategory<"Binary uint">;
-def UnaryFloatCategory : DXILOpCategory<"Unary float">;
-def ComputeIDCategory : DXILOpCategory<"Compute/Mesh/Amplification shader">;
-
-// Represent as any pointer type with an option to change to a qualified pointer
-// type with address space specified.
-def dxil_handle_ty  : LLVMAnyPointerType;
-def dxil_cbuffer_ty : LLVMAnyPointerType;
-def dxil_resource_ty : LLVMAnyPointerType;
-
-// The parameter description for a DXIL operation
-class DXILOpParameter<int pos, LLVMType type, string name, string doc,
-                 bit isConstant = 0, string enumName = "",
-                 int maxValue = 0> {
-  int Pos = pos;               // Position in parameter list
-  LLVMType ParamType = type;   // Parameter type
-  string Name = name;          // Short, unique parameter name
-  string Doc = doc;            // Description of this parameter
-  bit IsConstant = isConstant; // Whether this parameter requires a constant value in the IR
-  string EnumName = enumName;  // Name of the enum type, if applicable
-  int MaxValue = maxValue;     // Maximum value for this parameter, if applicable
-}
-
-// A representation for a DXIL operation
-class DXILOperationDesc {
-  string OpName = "";         // Name of DXIL operation
-  int OpCode = 0;             // Unique non-negative integer associated with the operation
-  DXILOpClass  OpClass;       // Class of the operation
-  DXILOpCategory OpCategory;  // Category of the operation
-  string Doc = "";            // Description of the operation
-  list<DXILOpParameter> Params = []; // Parameter list of the operation
-  list<LLVMType> OverloadTypes = [];  // Overload types, if applicable
-  EnumAttr Attribute;         // Operation Attribute. Leverage attributes defined in Attributes.td
-                              // ReadNone - operation does not access memory.
-                              // ReadOnly - only reads from memory.
-                              // "ReadMemory"   - reads memory
-  bit IsDerivative = 0;       // Whether this is some kind of derivative
-  bit IsGradient = 0;         // Whether this requires a gradient calculation
-  bit IsFeedback = 0;         // Whether this is a sampler feedback operation
-  bit IsWave = 0;             // Whether this requires in-wave, cross-lane functionality
-  bit NeedsUniformInputs = 0; // Whether this operation requires that all
-                              // of its inputs are uniform across the wave
-  // Group DXIL operation for stats - e.g., to accumulate the number of atomic/float/uint/int/...
-  // operations used in the program.
-  list<string> StatsGroup = [];
-}
-
-class DXILOperation<string name, int opCode, DXILOpClass opClass, DXILOpCategory opCategory, string doc,
-              list<LLVMType> oloadTypes, EnumAttr attrs, list<DXILOpParameter> params,
-              list<string> statsGroup = []> : DXILOperationDesc {
-  let OpName = name;
-  let OpCode = opCode;
-  let Doc = doc;
-  let Params = params;
-  let OpClass = opClass;
-  let OpCategory = opCategory;
-  let OverloadTypes = oloadTypes;
-  let Attribute = attrs;
-  let StatsGroup = statsGroup;
-}
 
 // LLVM intrinsic that DXIL operation maps to.
 class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
 
-def Sin : DXILOperation<"Sin", 13, UnaryClass, UnaryFloatCategory, "returns sine(theta) for theta in radians.",
-  [llvm_half_ty, llvm_float_ty], ReadNone,
-  [
-    DXILOpParameter<0, llvm_anyfloat_ty, "", "operation result">,
-    DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
-    DXILOpParameter<2, llvm_anyfloat_ty, "value", "input value">
-  ],
-  ["floats"]>,
-  LLVMIntrinsic<int_sin>;
-
-def UMax : DXILOperation< "UMax", 39, BinaryClass, BinaryUintCategory, "unsigned integer maximum. UMax(a,b) = a > b ? a : b",
-    [llvm_i16_ty, llvm_i32_ty, llvm_i64_ty], ReadNone,
-  [
-    DXILOpParameter<0, llvm_anyint_ty, "", "operation result">,
-    DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
-    DXILOpParameter<2, llvm_anyint_ty, "a", "input value">,
-    DXILOpParameter<3, llvm_anyint_ty, "b", "input value">
-  ],
-  ["uints"]>,
-  LLVMIntrinsic<int_umax>;
-
-def ThreadId : DXILOperation< "ThreadId", 93, ThreadIdClass, ComputeIDCategory, "reads the thread ID", [llvm_i32_ty], ReadNone,
-  [
-    DXILOpParameter<0, llvm_i32_ty, "", "thread ID component">,
-    DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
-    DXILOpParameter<2, llvm_i32_ty, "component", "component to read (x,y,z)">
-  ]>,
-  LLVMIntrinsic<int_dx_thread_id>;
-
-def GroupId : DXILOperation< "GroupId", 94, GroupIdClass, ComputeIDCategory, "reads the group ID (SV_GroupID)", [llvm_i32_ty], ReadNone,
-  [
-    DXILOpParameter<0, llvm_i32_ty, "", "group ID component">,
-    DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
-    DXILOpParameter<2, llvm_i32_ty, "component", "component to read">
-  ]>,
-  LLVMIntrinsic<int_dx_group_id>;
-
-def ThreadIdInGroup : DXILOperation< "ThreadIdInGroup", 95, ThreadIdInGroupClass, ComputeIDCategory,
-  "reads the thread ID within the group (SV_GroupThreadID)", [llvm_i32_ty], ReadNone,
-  [
-    DXILOpParameter<0, llvm_i32_ty, "", "thread ID in group component">,
-    DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">,
-    DXILOpParameter<2, llvm_i32_ty, "component", "component to read (x,y,z)">
-  ]>,
-  LLVMIntrinsic<int_dx_thread_id_in_group>;
+// Abstraction DXIL Operation to LLVM intrinsic
+class DXILOpMapping<int opCode, Intrinsic intrinsic, string doc> {
+  int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
+  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps to
+  string Doc = doc;                    // a short description of the operation
+}
 
-def FlattenedThreadIdInGroup : DXILOperation< "FlattenedThreadIdInGroup", 96, FlattenedThreadIdInGroupClass, ComputeIDCategory,
-   "provides a flattened index for a given thread within a given group (SV_GroupIndex)", [llvm_i32_ty], ReadNone,
-  [
-    DXILOpParameter<0, llvm_i32_ty, "", "result">,
-    DXILOpParameter<1, llvm_i32_ty, "opcode", "DXIL opcode">
-  ]>,
-  LLVMIntrinsic<int_dx_flattened_thread_id_in_group>;
+// Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
+def Sin      : DXILOpMapping<13, int_sin, "Returns sine(theta) for theta in radians.">;
+def UMax     : DXILOpMapping<39, int_umax, "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+def ThreadId : DXILOpMapping<93, int_dx_thread_id, "Reads the thread ID">;
+def GroupId  : DXILOpMapping<94, int_dx_group_id, "Reads the group ID (SV_GroupID)">;
+def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
+                                    "Reads the thread ID within the group (SV_GroupThreadID)">;
+def FlattenedThreadIdInGroup_New : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
+                                    "Provides a flattened index for a given thread within a given group (SV_GroupIndex)">;
\ No newline at end of file
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 42180a865b72e3..21a20d45b922d9 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -221,12 +221,26 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
   return nullptr;
 }
 
+/// Construct DXIL function type. This is the type of a function with
+/// the following prototype
+///     OverloadType dx.op.<opclass>.<return-type>(int opcode, <param-types>)
+/// <param-types> are constructed from types in Prop.
+/// \param Prop  Structure containing DXIL Operation properties based on
+///               its specification in DXIL.td.
+/// \param OverloadTy Return type to be used to construct DXIL function type.
 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
                                            Type *OverloadTy) {
   SmallVector<Type *> ArgTys;
 
   auto ParamKinds = getOpCodeParameterKind(*Prop);
 
+  // Add OverloadTy as return type of the function
+  ArgTys.emplace_back(OverloadTy);
+
+  // Add DXIL Opcode value type viz., Int32 as first argument
+  ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
+
+  // Add DXIL Operation parameter types as specified in DXIL properties
   for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
     ParameterKind Kind = ParamKinds[I];
     ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
@@ -267,13 +281,13 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
   return B.CreateCall(Fn, FullArgs);
 }
 
-Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
-                                   bool NoOpCodeParam) {
+Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
 
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
+  // If DXIL Op has no overload parameter, just return the
+  // precise return type specified.
   if (Prop->OverloadParamIndex < 0) {
     auto &Ctx = FT->getContext();
-    // When only has 1 overload type, just return it.
     switch (Prop->OverloadTys) {
     case OverloadKind::VOID:
       return Type::getVoidTy(Ctx);
@@ -302,9 +316,8 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
   // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
   Type *OverloadType = FT->getReturnType();
   if (Prop->OverloadParamIndex != 0) {
-    // Skip Return Type and Type for DXIL opcode.
-    const unsigned SkipedParam = NoOpCodeParam ? 2 : 1;
-    OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam);
+    // Skip Return Type.
+    OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);
   }
 
   auto ParamKinds = getOpCodeParameterKind(*Prop);
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 940ed538c7ce15..1c15f109184adf 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -31,8 +31,7 @@ class DXILOpBuilder {
   DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
   CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
                              llvm::iterator_range<Use *> Args);
-  Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT,
-                      bool NoOpCodeParam);
+  Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
 
 private:
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f6e2297e9af41f..6b649b76beecdf 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -33,8 +33,7 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
   Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
   DXILOpBuilder DXILB(M, B);
-  Type *OverloadTy =
-      DXILB.getOverloadTy(DXILOp, F.getFunctionType(), /*NoOpCodeParam*/ true);
+  Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
   for (User *U : make_early_inc_range(F.users())) {
     CallInst *CI = dyn_cast<CallInst>(U);
     if (!CI)
diff --git a/llvm/test/CodeGen/DirectX/comput_ids.ll b/llvm/test/CodeGen/DirectX/comput_ids.ll
index 553994094d71e5..c0ae5761b4970e 100644
--- a/llvm/test/CodeGen/DirectX/comput_ids.ll
+++ b/llvm/test/CodeGen/DirectX/comput_ids.ll
@@ -9,7 +9,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_thread_id(i32 %a) #0 {
 entry:
-; CHECK:call i32 @dx.op.threadId.i32(i32 93, i32 %{{.*}})
+; CHECK:call i32 @dx.op.unary.i32(i32 93, i32 %{{.*}})
   %0 = call i32 @llvm.dx.thread.id(i32 %a)
   ret i32 %0
 }
@@ -18,7 +18,7 @@ entry:
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_group_id(i32 %a) #0 {
 entry:
-; CHECK: call i32 @dx.op.groupId.i32(i32 94, i32 %{{.*}})
+; CHECK: call i32 @dx.op.unary.i32(i32 94, i32 %{{.*}})
   %0 = call i32 @llvm.dx.group.id(i32 %a)
   ret i32 %0
 }
@@ -27,7 +27,7 @@ entry:
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_thread_id_in_group(i32 %a) #0 {
 entry:
-; CHECK: call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 %{{.*}})
+; CHECK: call i32 @dx.op.unary.i32(i32 95, i32 %{{.*}})
   %0 = call i32 @llvm.dx.thread.id.in.group(i32 %a)
   ret i32 %0
 }
@@ -36,7 +36,7 @@ entry:
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_flattened_thread_id_in_group() #0 {
 entry:
-; CHECK: call i32 @dx.op.flattenedThreadIdInGroup.i32(i32 96)
+; CHECK: call i32 @dx.op.nullary.i32(i32 96)
   %0 = call i32 @llvm.dx.flattened.thread.id.in.group()
   ret i32 %0
 }
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index d47df597d53a35..a28830920eec21 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -13,6 +13,7 @@
 
 #include "SequenceToOffsetTable.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/StringSwitch.h"
@@ -30,28 +31,16 @@ struct DXILShaderModel {
   int Minor = 0;
 };
 
-struct DXILParameter {
-  int Pos; // position in parameter list
-  ParameterKind Kind;
-  StringRef Name; // short, unique name
-  StringRef Doc;  // the documentation description of this parameter
-  bool IsConst;   // whether this argument requires a constant value in the IR
-  StringRef EnumName; // the name of the enum type if applicable
-  int MaxValue;       // the maximum value for this parameter if applicable
-  DXILParameter(const Record *R);
-};
-
 struct DXILOperationDesc {
-  StringRef OpName;   // name of DXIL operation
+  std::string OpName; // name of DXIL operation
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
-  StringRef Category; // classification for this instruction
   StringRef Doc;      // the documentation description of this instruction
 
-  SmallVector<DXILParameter> Params; // the operands that this instruction takes
-  SmallVector<ParameterKind> OverloadTypes; // overload types if applicable
-  StringRef Attr; // operation attribute; reference to string representation
-                  // of llvm::Attribute::AttrKind
+  SmallVector<std::string> OpTypeNames; // Vector of operand type name strings -
+                                        // return type is at index 0
+  SmallVector<std::string>
+      OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
                         // means no map exists
   bool IsDeriv = false; // whether this is some kind of derivative
@@ -102,50 +91,67 @@ static ParameterKind lookupParameterKind(StringRef typeNameStr) {
   return paramKind;
 }
 
+/// Construct an object using the DXIL Operation records specified
+/// in DXIL.td. This serves as the single source of reference for
+/// C++ code generated by this TableGen backend.
 DXILOperationDesc::DXILOperationDesc(const Record *R) {
-  OpName = R->getValueAsString("OpName");
+  OpName = R->getNameInitAsString();
   OpCode = R->getValueAsInt("OpCode");
-  OpClass = R->getValueAsDef("OpClass")->getValueAsString("Name");
-  Category = R->getValueAsDef("OpCategory")->getValueAsString("Name");
 
-  if (R->getValue("llvm_intrinsic")) {
-    auto *IntrinsicDef = R->getValueAsDef("llvm_intrinsic");
+  Doc = R->getValueAsString("Doc");
+
+  if (R->getValue("LLVMIntrinsic")) {
+    auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
     auto DefName = IntrinsicDef->getName();
     assert(DefName.starts_with("int_") && "invalid intrinsic name");
     // Remove the int_ from intrinsic name.
     Intrinsic = DefName.substr(4);
+    // NOTE: It is expected that return type and parameter types of
+    // DXIL Operation are the same as that of the intrinsic. Deviations
+    // are expected to be encoded in TableGen record specification and
+    // handled accordingly here. Support to be added later, as needed.
+    // Get parameter type list of the intrinsic. Types attribute contains
+    // the list as [returnType, param1Type, param2Type, ...]
+    auto TypeList = IntrinsicDef->getValueAsListInit("Types");
+    unsigned TypeListSize = TypeList->size();
+    OverloadParamIndex = -1;
+    // Populate return type and parameter type names
+    for (unsigned i = 0; i < TypeListSize; i++) {
+      OpTypeNames.emplace_back(TypeList->getElement(i)->getAsString());
+      // Get the overload parameter index.
+      // REVISIT : Seems hacky. Is it possible that more than one parameter can
+      // be of overload kind?? REVISIT-2: Check for any additional constraints
+      // specified for DXIL operation restricting return type.
+      if (i > 0) {
+        auto &CurParam = OpTypeNames.back();
+        if (lookupParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
+          OverloadParamIndex = i;
+        }
+      }
+    }
+    // Determine the operation class (unary/binary) based on the number of
+    // parameters. As parameter types are being considered, skip return type.
+    auto ParamSize = TypeListSize - 1;
+    if (ParamSize == 0) {
+      OpClass = "Nullary";
+    } else if (ParamSize == 1) {
+      OpClass = "Unary";
+    } else if (ParamSize == 2) {
+      OpClass = "Binary";
+    } else {
+      // TODO: Extend as needed
+      llvm_unreachable("Unhandled parameter size");
+    }
+    // NOTE: For now, assume that attributes of DXIL Operation are the same as
+    // that of the intrinsic. Deviations are expected to be encoded in TableGen
+    // record specification and handled accordingly here. Support to be added
+    // later.
+    auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
+    auto IntrPropListSize = IntrPropList->size();
+    for (unsigned i = 0; i < IntrPropListSize; i++) {
+      OpAttributes.emplace_back(IntrPropList->getElement(i)->getAsString());
+    }
   }
-
-  Doc = R->getValueAsString("Doc");
-
-  ListInit *ParamList = R->getValueAsListInit("Params");
-  OverloadParamIndex = -1;
-  for (unsigned I = 0; I < ParamList->size(); ++I) {
-    Record *Param = ParamList->getElementAsRecord(I);
-    Params.emplace_back(DXILParameter(Param));
-    auto &CurParam = Params.back();
-    if (CurParam.Kind >= ParameterKind::OVERLOAD)
-      OverloadParamIndex = I;
-  }
-  ListInit *OverloadTypeList = R->getValueAsListInit("OverloadTypes");
-
-  for (unsigned I = 0; I < OverloadTypeList->size(); ++I) {
-    Record *R = OverloadTypeList->getElementAsRecord(I);
-    OverloadTypes.emplace_back(lookupParameterKind(R->getNameInitAsString()));
-  }
-  Attr = StringRef(R->getValue("Attribute")->getNameInitAsString());
-}
-
-DXILParameter::DXILParameter(const Record *R) {
-  Name = R->getValueAsString("Name");
-  Pos = R->getValueAsInt("Pos");
-  Kind =
-      lookupParameterKind(R->getValue("ParamType")->getValue()->getAsString());
-  if (R->getValue("Doc"))
-    Doc = R->getValueAsString("Doc");
-  IsConst = R->getValueAsBit("IsConstant");
-  EnumName = R->getValueAsString("EnumName");
-  MaxValue = R->getValueAsInt("MaxValue");
 }
 
 static std::string parameterKindToString(ParameterKind Kind) {
@@ -187,82 +193,31 @@ static void emitDXILOpEnum(DXILOperationDesc &Op, raw_ostream &OS) {
   OS << Op.OpName << " = " << Op.OpCode << ", // " << Op.Doc << "\n";
 }
 
-static std::string buildCategoryStr(StringSet<> &Cetegorys) {
-  std::string Str;
-  raw_string_ostream OS(Str);
-  for (auto &It : Cetegorys) {
-    OS << " " << It.getKey();
-  }
-  return OS.str();
-}
-
-// Emit enum declaration for DXIL.
 static void emitDXILEnums(std::vector<DXILOperationDesc> &Ops,
                           raw_ostream &OS) {
-  // Sort by Category + OpName.
+  // Sort by OpCode
   llvm::sort(Ops, [](DXILOperationDesc &A, DXILOperationDesc &B) {
-    // Group by Category first.
-    if (A.Category == B.Category)
-      // Inside same Category, order by OpName.
-      return A.OpName < B.OpName;
-    else
-      return A.Category < B.Category;
+    return A.OpCode < B.OpCode;
   });
 
   OS << "// Enumeration for operations specified by DXIL\n";
   OS << "enum class OpCode : unsigned {\n";
 
-  StringMap<StringSet<>> ClassMap;
-  StringRef PrevCategory = "";
   for (auto &Op : Ops) {
-    StringRef Category = Op.Category;
-    if (Category != PrevCategory) {
-      OS << "\n// " << Category << "\n";
-      PrevCategory = Category;
-    }
     emitDXILOpEnum(Op, OS);
-    auto It = ClassMap.find(Op.OpClass);
-    if (It != ClassMap.end()) {
-      It->second.insert(Op.Category);
-    } else {
-      ClassMap[Op.OpClass].insert(Op.Category);
-    }
   }
 
   OS << "\n};\n\n";
 
-  std::vector<std::pair<std::string, std::string>> ClassVec;
-  for (auto &It : ClassMap) {
-    ClassVec.emplace_back(
-        std::pair(It.getKey().str(), buildCategoryStr(It.second)));
-  }
-  // Sort by Category + ClassName.
-  llvm::sort(ClassVec, [](std::pair<std::string, std::string> &A,
-                          std::pair<std::string, std::string> &B) {
-    StringRef ClassA = A.first;
-    StringRef CategoryA = A.second;
-    StringRef ClassB = B.first;
-    StringRef CategoryB = B.second;
-    // Group by Category first.
-    if (CategoryA == CategoryB)
-      // Inside same Category, order by ClassName.
-      return ClassA < ClassB;
-    else
-      return CategoryA < CategoryB;
-  });
-
   OS << "// Groups for DXIL operations with equivalent function templates\n";
   OS << "enum class OpCodeClass : unsigned {\n";
-  PrevCategory = "";
-  for (auto &It : ClassVec) {
-
-    StringRef Category = It.second;
-    if (Category != PrevCategory) {
-      OS << "\n// " << Category << "\n";
-      PrevCategory = Category;
-    }
-    StringRef Name = It.first;
-    OS << Name << ",\n";
+  // Build an OpClass set to print
+  SmallSet<StringRef, 2> OpClassSet;
+  for (auto &Op : Ops) {
+    OpClassSet.insert(Op.OpClass);
+  }
+  for (auto &C : OpClassSet) {
+    OS << C << ",\n";
   }
   OS << "\n};\n\n";
 }
@@ -291,49 +246,37 @@ static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
  @param Attr string reference
  @return std::string Attribute enum string
  */
-static std::string emitDXILOperationAttr(StringRef Attr) {
-  return StringSwitch<std::string>(Attr)
-      .Case("ReadNone", "Attribute::ReadNone")
-      .Case("ReadOnly", "Attribute::ReadOnly")
-      .Default("Attribute::None");
-}
-
-static std::string overloadKindStr(ParameterKind Overload) {
-  switch (Overload) {
-  case ParameterKind::HALF:
-    return "OverloadKind::HALF";
-  case ParameterKind::FLOAT:
-    return "OverloadKind::FLOAT";
-  case ParameterKind::DOUBLE:
-    return "OverloadKind::DOUBLE";
-  case ParameterKind::I1:
-    return "OverloadKind::I1";
-  case ParameterKind::I8:
-    return "OverloadKind::I8";
-  case ParameterKind::I16:
-    return "OverloadKind::I16";
-  case ParameterKind::I32:
-    return "OverloadKind::I32";
-  case ParameterKind::I64:
-    return "OverloadKind::I64";
-  case ParameterKind::VOID:
-    return "OverloadKind::VOID";
-  default:
-    return "OverloadKind::UNKNOWN";
+static std::string emitDXILOperationAttr(SmallVector<std::string> Attrs) {
+  for (auto Attr : Attrs) {
+    // For now just recognize IntrNoMem and IntrReadMem as valid and ignore
+    // others
+    if (Attr == "IntrNoMem") {
+      return "Attribute::ReadNone";
+    } else if (Attr == "IntrReadMem") {
+      return "Attribute::ReadOnly";
+    }
   }
+  return "Attribute::None";
 }
 
-static std::string
-getDXILOperationOverloads(SmallVector<ParameterKind> Overloads) {
-  // Format is: OverloadKind::FLOAT | OverloadKind::HALF
-  auto It = Overloads.begin();
-  std::string Result;
-  raw_string_ostream OS(Result);
-  OS << overloadKindStr(*It);
-  for (++It; It != Overloads.end(); ++It) {
-    OS << " | " << overloadKindStr(*It);
-  }
-  return OS.str();
+static std::string emitOverloadKindStr(std::string OpTypeStr) {
+  std::string Result =
+      StringSwitch<std::string>(OpTypeStr)
+          .Case("llvm_i16_ty", "OverloadKind::I16")
+          .Case("llvm_i32_ty", "OverloadKind::I32")
+          .Case("llvm_i64_ty", "OverloadKind::I64")
+          .Case("llvm_anyint_ty",
+                "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64")
+          .Case("llvm_half_ty", "OverloadKind::HALF")
+          .Case("llvm_float_ty", "OverloadKind::FLOAT")
+          .Case("llvm_double_ty", "OverloadKind::DOUBLE")
+          .Case(
+              "llvm_anyfloat_ty",
+              "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE")
+          .Default("UNHANDLED_TYPE");
+
+  assert(Result != "UNHANDLED_TYPE" && "Unhandled parameter type");
+  return Result;
 }
 
 static std::string lowerFirstLetter(StringRef Name) {
@@ -369,15 +312,16 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
   StringMap<SmallVector<ParameterKind>> ParameterMap;
   StringSet<> ClassSet;
   for (auto &Op : Ops) {
-    OpStrings.add(Op.OpName.str());
+    OpStrings.add(Op.OpName);
 
     if (ClassSet.contains(Op.OpClass))
       continue;
     ClassSet.insert(Op.OpClass);
     OpClassStrings.add(getDXILOpClassName(Op.OpClass));
     SmallVector<ParameterKind> ParamKindVec;
-    for (auto &Param : Op.Params) {
-      ParamKindVec.emplace_back(Param.Kind);
+    // ParamKindVec is a vector of parameters. Skip return type at index 0
+    for (unsigned i = 1; i < Op.OpTypeNames.size(); i++) {
+      ParamKindVec.emplace_back(lookupParameterKind(Op.OpTypeNames[i]));
     }
     ParameterMap[Op.OpClass] = ParamKindVec;
     Parameters.add(ParamKindVec);
@@ -398,12 +342,12 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
   for (auto &Op : Ops) {
-    OS << "  { dxil::OpCode::" << Op.OpName << ", "
-       << OpStrings.get(Op.OpName.str()) << ", OpCodeClass::" << Op.OpClass
-       << ", " << OpClassStrings.get(getDXILOpClassName(Op.OpClass)) << ", "
-       << getDXILOperationOverloads(Op.OverloadTypes) << ", "
-       << emitDXILOperationAttr(Op.Attr) << ", " << Op.OverloadParamIndex
-       << ", " << Op.Params.size() << ", "
+    OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
+       << ", OpCodeClass::" << Op.OpClass << ", "
+       << OpClassStrings.get(getDXILOpClassName(Op.OpClass)) << ", "
+       << emitOverloadKindStr(Op.OpTypeNames[0]) << ", "
+       << emitDXILOperationAttr(Op.OpAttributes) << ", "
+       << Op.OverloadParamIndex << ", " << Op.OpTypeNames.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
   }
   OS << "  };\n";
@@ -460,29 +404,24 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
 }
 
 static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
-  std::vector<Record *> Ops = Records.getAllDerivedDefinitions("DXILOperation");
   OS << "// Generated code, do not edit.\n";
   OS << "\n";
-
+  // Get all DXIL Ops to intrinsic mapping records
+  std::vector<Record *> OpIntrMaps =
+      Records.getAllDerivedDefinitions("DXILOpMapping");
   std::vector<DXILOperationDesc> DXILOps;
-  DXILOps.reserve(Ops.size());
-  for (auto *Record : Ops) {
+  for (auto *Record : OpIntrMaps) {
     DXILOps.emplace_back(DXILOperationDesc(Record));
   }
-
   OS << "#ifdef DXIL_OP_ENUM\n";
   emitDXILEnums(DXILOps, OS);
   OS << "#endif\n\n";
-
   OS << "#ifdef DXIL_OP_INTRINSIC_MAP\n";
   emitDXILIntrinsicMap(DXILOps, OS);
   OS << "#endif\n\n";
-
   OS << "#ifdef DXIL_OP_OPERATION_TABLE\n";
   emitDXILOperationTable(DXILOps, OS);
   OS << "#endif\n\n";
-
-  OS << "\n";
 }
 
 static TableGen::Emitter::Opt X("gen-dxil-operation", EmitDXILOperation,

>From 9f2cc515b74efa490b51f1cdc147f2d74996a131 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 28 Feb 2024 10:33:46 -0500
Subject: [PATCH 2/3] PR feedback: Format and comment cleanup

---
 llvm/lib/Target/DirectX/DXIL.td     | 27 ++++++++++++++++++---------
 llvm/utils/TableGen/DXILEmitter.cpp | 13 +++++++------
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 447887fbd474f8..2f1ffc33190fda 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -14,21 +14,30 @@
 include "llvm/IR/Intrinsics.td"
 
 // LLVM intrinsic that DXIL operation maps to.
-class LLVMIntrinsic<Intrinsic llvm_intrinsic_> { Intrinsic llvm_intrinsic = llvm_intrinsic_; }
+class LLVMIntrinsic<Intrinsic llvm_intrinsic_> {
+                                  Intrinsic llvm_intrinsic = llvm_intrinsic_;
+                                 }
 
 // Abstraction DXIL Operation to LLVM intrinsic
 class DXILOpMapping<int opCode, Intrinsic intrinsic, string doc> {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps to
-  string Doc = doc;                    // a short description of the operation
+  Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Op maps to
+  string Doc = doc;                    // A short description of the operation
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
-def Sin      : DXILOpMapping<13, int_sin, "Returns sine(theta) for theta in radians.">;
-def UMax     : DXILOpMapping<39, int_umax, "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
+def Sin  : DXILOpMapping<13, int_sin,
+                         "Returns sine(theta) for theta in radians.">;
+def UMax : DXILOpMapping<39, int_umax,
+                         "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
 def ThreadId : DXILOpMapping<93, int_dx_thread_id, "Reads the thread ID">;
-def GroupId  : DXILOpMapping<94, int_dx_group_id, "Reads the group ID (SV_GroupID)">;
+def GroupId  : DXILOpMapping<94, int_dx_group_id,
+                             "Reads the group ID (SV_GroupID)">;
 def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
-                                    "Reads the thread ID within the group (SV_GroupThreadID)">;
-def FlattenedThreadIdInGroup_New : DXILOpMapping<96, int_dx_flattened_thread_id_in_group,
-                                    "Provides a flattened index for a given thread within a given group (SV_GroupIndex)">;
\ No newline at end of file
+                                    "Reads the thread ID within the group "
+                                    "(SV_GroupThreadID)">;
+def FlattenedThreadIdInGroup : DXILOpMapping<96,
+                                             int_dx_flattened_thread_id_in_group,
+                                             "Provides a flattened index for a "
+                                             "given thread within a given "
+                                             "group (SV_GroupIndex)">;
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index a28830920eec21..1f7526fd82ddc3 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -106,7 +106,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
     assert(DefName.starts_with("int_") && "invalid intrinsic name");
     // Remove the int_ from intrinsic name.
     Intrinsic = DefName.substr(4);
-    // NOTE: It is expected that return type and parameter types of
+    // TODO: It is expected that return type and parameter types of
     // DXIL Operation are the same as that of the intrinsic. Deviations
     // are expected to be encoded in TableGen record specification and
     // handled accordingly here. Support to be added later, as needed.
@@ -119,9 +119,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
     for (unsigned i = 0; i < TypeListSize; i++) {
       OpTypeNames.emplace_back(TypeList->getElement(i)->getAsString());
       // Get the overload parameter index.
-      // REVISIT : Seems hacky. Is it possible that more than one parameter can
-      // be of overload kind?? REVISIT-2: Check for any additional constraints
-      // specified for DXIL operation restricting return type.
+      // TODO: Seems hacky. Is it possible that more than one parameter can
+      // be of overload kind?
+      // TODO: Check for any additional constraints specified for DXIL operation
+      // restricting return type.
       if (i > 0) {
         auto &CurParam = OpTypeNames.back();
         if (lookupParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
@@ -248,8 +249,8 @@ static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
  */
 static std::string emitDXILOperationAttr(SmallVector<std::string> Attrs) {
   for (auto Attr : Attrs) {
-    // For now just recognize IntrNoMem and IntrReadMem as valid and ignore
-    // others
+    // TODO: For now just recognize IntrNoMem and IntrReadMem as valid and
+    // ignore others.
     if (Attr == "IntrNoMem") {
       return "Attribute::ReadNone";
     } else if (Attr == "IntrReadMem") {

>From c47f3b6a1384cc4e43ec5cba102666eb7ce286fa Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 28 Feb 2024 14:19:04 -0500
Subject: [PATCH 3/3] [DirectX][NFC] Add DXILOpClass attribute to DXILOpMapping

A unique OpClass string is associated with each of the DXIL operations.
This is needed to construct a valid DXIL operation function name to
be called in the lowered DXIL code.
---
 llvm/lib/Target/DirectX/DXIL.td         | 34 ++++++++++++++++++++-----
 llvm/test/CodeGen/DirectX/comput_ids.ll |  8 +++---
 llvm/utils/TableGen/DXILEmitter.cpp     | 16 +++---------
 3 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 2f1ffc33190fda..c37609562595c6 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -13,30 +13,50 @@
 
 include "llvm/IR/Intrinsics.td"
 
+class DXILOpClass;
+
+// Following is a set of DXIL Operation classes whose names appear to be
+// arbitrary, yet need to be a substring of the function name used during
+// lowering to DXIL Operation calls. These class name strings are specified as
+// the third argument of add_dxil_op in utils/hct/hctdb.py. The function name
+// has the format "dx.op.<class-name>.<return-type>".
+
+defset list<DXILOpClass> OpClasses = {
+  def unary : DXILOpClass;
+  def binary : DXILOpClass;
+  def threadId : DXILOpClass;
+  def groupId : DXILOpClass;
+  def threadIdInGroup : DXILOpClass;
+  def flattenedThreadIdInGroup : DXILOpClass;
+}
+
 // LLVM intrinsic that DXIL operation maps to.
 class LLVMIntrinsic<Intrinsic llvm_intrinsic_> {
                                   Intrinsic llvm_intrinsic = llvm_intrinsic_;
                                  }
 
 // Abstraction DXIL Operation to LLVM intrinsic
-class DXILOpMapping<int opCode, Intrinsic intrinsic, string doc> {
+class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
+  DXILOpClass OpClass = opClass;             // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
-def Sin  : DXILOpMapping<13, int_sin,
+def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.">;
-def UMax : DXILOpMapping<39, int_umax,
+def UMax : DXILOpMapping<39, binary, int_umax,
                          "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
-def ThreadId : DXILOpMapping<93, int_dx_thread_id, "Reads the thread ID">;
-def GroupId  : DXILOpMapping<94, int_dx_group_id,
+def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
+                             "Reads the thread ID">;
+def GroupId  : DXILOpMapping<94, groupId, int_dx_group_id,
                              "Reads the group ID (SV_GroupID)">;
-def ThreadIdInGroup : DXILOpMapping<95, int_dx_thread_id_in_group,
+def ThreadIdInGroup : DXILOpMapping<95, threadIdInGroup,
+                                    int_dx_thread_id_in_group,
                                     "Reads the thread ID within the group "
                                     "(SV_GroupThreadID)">;
-def FlattenedThreadIdInGroup : DXILOpMapping<96,
+def FlattenedThreadIdInGroup : DXILOpMapping<96, flattenedThreadIdInGroup,
                                              int_dx_flattened_thread_id_in_group,
                                              "Provides a flattened index for a "
                                              "given thread within a given "
diff --git a/llvm/test/CodeGen/DirectX/comput_ids.ll b/llvm/test/CodeGen/DirectX/comput_ids.ll
index c0ae5761b4970e..553994094d71e5 100644
--- a/llvm/test/CodeGen/DirectX/comput_ids.ll
+++ b/llvm/test/CodeGen/DirectX/comput_ids.ll
@@ -9,7 +9,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_thread_id(i32 %a) #0 {
 entry:
-; CHECK:call i32 @dx.op.unary.i32(i32 93, i32 %{{.*}})
+; CHECK:call i32 @dx.op.threadId.i32(i32 93, i32 %{{.*}})
   %0 = call i32 @llvm.dx.thread.id(i32 %a)
   ret i32 %0
 }
@@ -18,7 +18,7 @@ entry:
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_group_id(i32 %a) #0 {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 94, i32 %{{.*}})
+; CHECK: call i32 @dx.op.groupId.i32(i32 94, i32 %{{.*}})
   %0 = call i32 @llvm.dx.group.id(i32 %a)
   ret i32 %0
 }
@@ -27,7 +27,7 @@ entry:
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_thread_id_in_group(i32 %a) #0 {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 95, i32 %{{.*}})
+; CHECK: call i32 @dx.op.threadIdInGroup.i32(i32 95, i32 %{{.*}})
   %0 = call i32 @llvm.dx.thread.id.in.group(i32 %a)
   ret i32 %0
 }
@@ -36,7 +36,7 @@ entry:
 ; Function Attrs: noinline nounwind optnone
 define i32 @test_flattened_thread_id_in_group() #0 {
 entry:
-; CHECK: call i32 @dx.op.nullary.i32(i32 96)
+; CHECK: call i32 @dx.op.flattenedThreadIdInGroup.i32(i32 96)
   %0 = call i32 @llvm.dx.flattened.thread.id.in.group()
   ret i32 %0
 }
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 1f7526fd82ddc3..1eca7cde0fb2a7 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -130,19 +130,9 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
         }
       }
     }
-    // Determine the operation class (unary/binary) based on the number of
-    // parameters As parameter types are being considered, skip return type
-    auto ParamSize = TypeListSize - 1;
-    if (ParamSize == 0) {
-      OpClass = "Nullary";
-    } else if (ParamSize == 1) {
-      OpClass = "Unary";
-    } else if (ParamSize == 2) {
-      OpClass = "Binary";
-    } else {
-      // TODO: Extend as needed
-      llvm_unreachable("Unhandled parameter size");
-    }
+    // Get the operation class
+    OpClass = R->getValueAsDef("OpClass")->getName();
+
     // NOTE: For now, assume that attributes of DXIL Operation are the same as
     // that of the intrinsic. Deviations are expected to be encoded in TableGen
     // record specification and handled accordingly here. Support to be added



More information about the llvm-commits mailing list