[llvm] [DirectX][NFC] Model precise overload type specification of DXIL Ops (PR #83917)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 14:08:34 PST 2024


https://github.com/bharadwajy created https://github.com/llvm/llvm-project/pull/83917

Implements an abstraction to specify precise overload types supported by DXIL ops. These overload types are typically a subset of those supported by the LLVM intrinsics that the DXIL ops map to.

Implements the corresponding changes in DXILEmitter backend.

Adds tests to verify expected errors for unsupported overload types at code generation time.

Adds tests to check for correct overload error output.

>From ecc20803edbd7cf5a3dbb15339122bfd5bbad14b Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Sat, 2 Mar 2024 21:19:53 -0500
Subject: [PATCH] [DirectX][NFC] Model precise overload type specification of
 DXIL Ops

Implements an abstraction to specify precise overload types supported by
DXIL ops. These overload types are typically a subset of those supported
by the LLVM intrinsics that the DXIL ops map to.

Implements the corresponding changes in DXILEmitter backend.

Adds tests to verify expected errors for unsupported overload types
at code generation time.

Adds tests to check for correct overload error output.
---
 llvm/lib/Target/DirectX/DXIL.td           |  48 ++++++-
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp |   2 +-
 llvm/test/CodeGen/DirectX/frac.ll         |   3 -
 llvm/test/CodeGen/DirectX/frac_error.ll   |  14 ++
 llvm/test/CodeGen/DirectX/round_error.ll  |  13 ++
 llvm/test/CodeGen/DirectX/sin.ll          |  19 +--
 llvm/test/CodeGen/DirectX/sin_error.ll    |  14 ++
 llvm/utils/TableGen/DXILEmitter.cpp       | 150 +++++++++++++++-------
 8 files changed, 194 insertions(+), 69 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/frac_error.ll
 create mode 100644 llvm/test/CodeGen/DirectX/round_error.ll
 create mode 100644 llvm/test/CodeGen/DirectX/sin_error.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 33b08ed93e3d0a..762664176048e4 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -205,25 +205,63 @@ defset list<DXILOpClass> OpClasses = {
   def writeSamplerFeedbackBias : DXILOpClass;
   def writeSamplerFeedbackGrad : DXILOpClass;
   def writeSamplerFeedbackLevel: DXILOpClass;
+
+  // This is a sentinel definition. Hence placed at the end of the list
+  // and not as part of the above alphabetically sorted valid definitions.
+  // Additionally it is capitalized unlike all the others.
+  def UnknownOpClass: DXILOpClass;
+}
+
+// Several overloaded DXIL Operations support only a subset of the data
+// types accepted by the overloaded LLVM intrinsics they map to. For example,
+// the llvm.sin.* intrinsic operates on any floating-point type and is lowered
+// to DXIL Op Sin, but the only valid overloads of DXIL Sin are half (f16) and
+// float (f32).
+//
+// DXILType abstracts such DXIL-specific overload types (isAny = 1).
+
+class DXILType : LLVMType<OtherVT> {
+  let isAny = 1;
 }
 
+// Concrete overload types supported specifically by DXIL Operations. The
+// TableGen DXILEmitter recognizes these by record name; keep them in sync.
+
+def llvm_i16ori32_ty : DXILType;
+def llvm_halforfloat_ty : DXILType;
+
 // Abstraction DXIL Operation to LLVM intrinsic
-class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
+class DXILOpMappingBase {
+  int OpCode = 0;                      // Opcode of DXIL Operation
+  DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
+  Intrinsic LLVMIntrinsic = ?;         // LLVM intrinsic the DXIL Op maps to
+  string Doc = "";                     // A short description of the operation
+  list<LLVMType> OpTypes = ?;          // Valid types of the DXIL Operation, in
+                                       // the format [returnTy, param1Ty, ...]
+}
+
+class DXILOpMapping<int opCode, DXILOpClass opClass,
+                    Intrinsic intrinsic, string doc,
+                    list<LLVMType> opTys = []> : DXILOpMappingBase {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;             // Class of DXIL Operation.
+  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
+  list<LLVMType> OpTypes = !if(!empty(opTys), LLVMIntrinsic.Types, opTys);
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
 def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.">;
+                         "Returns sine(theta) for theta in radians.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.">;
+                         "decimal part of the input.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.">;
+                         " within a floating-point type.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def UMax : DXILOpMapping<39, binary, int_umax,
                          "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
 def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 21a20d45b922d9..99bc5c214db298 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
   // FIXME: find the issue and report error in clang instead of check it in
   // backend.
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
-    llvm_unreachable("invalid overload");
+    report_fatal_error("Invalid Overload Type", false);
   }
 
   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
diff --git a/llvm/test/CodeGen/DirectX/frac.ll b/llvm/test/CodeGen/DirectX/frac.ll
index ab605ed6084aa4..ae86fe06654da1 100644
--- a/llvm/test/CodeGen/DirectX/frac.ll
+++ b/llvm/test/CodeGen/DirectX/frac.ll
@@ -29,6 +29,3 @@ entry:
   %dx.frac = call half @llvm.dx.frac.f16(half %0)
   ret half %dx.frac
 }
-
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare half @llvm.dx.frac.f16(half) #1
diff --git a/llvm/test/CodeGen/DirectX/frac_error.ll b/llvm/test/CodeGen/DirectX/frac_error.ll
new file mode 100644
index 00000000000000..ad51bb0b6dee71
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/frac_error.ll
@@ -0,0 +1,14 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+; DXIL Frac supports only half and float overloads, so double is rejected.
+define noundef double @frac_double(double noundef %a) {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %dx.frac = call double @llvm.dx.frac.f64(double %0)
+  ret double %dx.frac
+}
diff --git a/llvm/test/CodeGen/DirectX/round_error.ll b/llvm/test/CodeGen/DirectX/round_error.ll
new file mode 100644
index 00000000000000..3bd87b2bbf0200
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/round_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL Round only supports half and float overloads, so this must fail with:
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef double @round_double(double noundef %a) {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %elt.round = call double @llvm.round.f64(double %0)
+  ret double %elt.round
+}
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index bb31d28bfcfee6..ea6f4d6200ee49 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -4,11 +4,8 @@
 ; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
 ; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
 
-target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
-target triple = "dxil-pc-shadermodel6.7-library"
-
 ; Function Attrs: noinline nounwind optnone
-define noundef float @_Z3foof(float noundef %a) #0 {
+define noundef float @sin_float(float noundef %a) #0 {
 entry:
   %a.addr = alloca float, align 4
   store float %a, ptr %a.addr, align 4
@@ -21,7 +18,7 @@ entry:
 declare float @llvm.sin.f32(float) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef half @_Z3barDh(half noundef %a) #0 {
+define noundef half @sin_half(half noundef %a) #0 {
 entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
@@ -29,15 +26,3 @@ entry:
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
-
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare half @llvm.sin.f16(half) #1
-
-attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
-attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
-
-!llvm.module.flags = !{!0}
-!llvm.ident = !{!1}
-
-!0 = !{i32 1, !"wchar_size", i32 4}
-!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
new file mode 100644
index 00000000000000..2a839faff00fe7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -0,0 +1,14 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL Sin only supports half and float overloads, so this must fail with:
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef double @sin_double(double noundef %a) {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %1 = call double @llvm.sin.f64(double %0)
+  ret double %1
+}
+
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index fc958f5328736c..ac355464fe6ca5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Support/DXILABI.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
+#include <string>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -38,8 +39,8 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  SmallVector<MVT::SimpleValueType> OpTypes; // Vector of operand types -
-                                             // return type is at index 0
+  SmallVector<Record *> OpTypes; // Vector of operand type records -
+                                 // return type is at index 0
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -57,20 +58,21 @@ struct DXILOperationDesc {
   DXILShaderModel ShaderModel;           // minimum shader model required
   DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                          // translation by linker
-  int OverloadParamIndex; // parameter index which control the overload.
-                          // When < 0, should be only 1 overload type.
+  int OverloadParamIndex;             // Index into OpTypes of the overload
+                                      // type (0 = return type); -1 if none.
   SmallVector<StringRef, 4> counters; // counters for this inst.
   DXILOperationDesc(const Record *);
 };
 } // end anonymous namespace
 
-/// Convert DXIL type name string to dxil::ParameterKind
+/// Return dxil::ParameterKind corresponding to input LLVMType record
 ///
-/// \param VT Simple Value Type
+/// \param R TableGen def record of class LLVMType
 /// \return ParameterKind As defined in llvm/Support/DXILABI.h
 
-static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
-  switch (VT) {
+static ParameterKind getParameterKind(const Record *R) {
+  auto VTRec = R->getValueAsDef("VT");
+  switch (getValueType(VTRec)) {
   case MVT::isVoid:
     return ParameterKind::VOID;
   case MVT::f16:
@@ -90,6 +92,18 @@ static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
   case MVT::fAny:
   case MVT::iAny:
     return ParameterKind::OVERLOAD;
+  case MVT::Other:
+    // DXIL-specific overload types (DXILType records in DXIL.td)
+    {
+      auto RetKind = StringSwitch<ParameterKind>(R->getName())
+                         .Cases("llvm_i16ori32_ty", "llvm_halforfloat_ty",
+                                ParameterKind::OVERLOAD)
+                         .Default(ParameterKind::INVALID);
+      if (RetKind != ParameterKind::INVALID) {
+        return RetKind;
+      }
+    }
+    [[fallthrough]];
   default:
     llvm_unreachable("Support for specified DXIL Type not yet implemented");
   }
@@ -106,45 +120,80 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
 
   Doc = R->getValueAsString("Doc");
 
+  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
+  unsigned TypeRecsSize = TypeRecs.size();
+  // Populate OpTypes with return type and parameter types
+
+  // Parameter indices of overloaded parameters.
+  // This vector records overload parameters in the order used to resolve
+  // an LLVMMatchType, in accordance with the convention outlined in the
+  // comment before the definition of class LLVMMatchType in
+  // llvm/IR/Intrinsics.td
+  SmallVector<int> OverloadParamIndices;
+  for (unsigned i = 0; i < TypeRecsSize; i++) {
+    auto TR = TypeRecs[i];
+    // Track operation parameter indices of any overload types
+    auto isAny = TR->getValueAsInt("isAny");
+    if (isAny == 1) {
+      // TODO: At present it is expected that all overload types in a DXIL Op
+      // are of the same type. Hence, OverloadParamIndices will have only one
+      // element. This implies we do not need a vector. However, until more
+      // (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
+      // cases where this assumption would not hold.
+      if (!OverloadParamIndices.empty()) {
+        bool knownType = true;
+        // Ensure that the same overload type registered earlier is being used
+        for (auto Idx : OverloadParamIndices) {
+          if (TR != TypeRecs[Idx]) {
+            knownType = false;
+            break;
+          }
+        }
+        if (!knownType) {
+          report_fatal_error("Specification of multiple differing overload "
+                             "parameter types not yet supported",
+                             false);
+        }
+      } else {
+        OverloadParamIndices.push_back(i);
+      }
+    }
+    // Populate OpTypes array according to the type specification
+    if (TR->isAnonymous()) {
+      // An anonymous record here is an inline LLVMMatchType<N>. N indexes the
+      // overload types recorded in OverloadParamIndices, not TypeRecs itself.
+      unsigned OLParamIndex = TR->getValueAsInt("Number");
+      assert(OLParamIndex < OverloadParamIndices.size() &&
+             "No prior overloaded parameter found to match.");
+      // Resolve to the type record of the referenced overload parameter
+      OpTypes.emplace_back(TypeRecs[OverloadParamIndices[OLParamIndex]]);
+    } else {
+      // A non-anonymous type. Just record it in OpTypes
+      OpTypes.emplace_back(TR);
+    }
+  }
+
+  // Set the index of the overload parameter, if any. At most one is recorded.
+  OverloadParamIndex = -1; // default; indicating none
+  if (!OverloadParamIndices.empty()) {
+    if (OverloadParamIndices.size() > 1)
+      report_fatal_error("Multiple overload type specification not supported",
+                         false);
+    OverloadParamIndex = OverloadParamIndices[0];
+  }
+  // Get the operation class
+  OpClass = R->getValueAsDef("OpClass")->getName();
+
   if (R->getValue("LLVMIntrinsic")) {
     auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
     auto DefName = IntrinsicDef->getName();
     assert(DefName.starts_with("int_") && "invalid intrinsic name");
     // Remove the int_ from intrinsic name.
     Intrinsic = DefName.substr(4);
-    // TODO: It is expected that return type and parameter types of
-    // DXIL Operation are the same as that of the intrinsic. Deviations
-    // are expected to be encoded in TableGen record specification and
-    // handled accordingly here. Support to be added later, as needed.
-    // Get parameter type list of the intrinsic. Types attribute contains
-    // the list of as [returnType, param1Type,, param2Type, ...]
-
-    OverloadParamIndex = -1;
-    auto TypeRecs = IntrinsicDef->getValueAsListOfDefs("Types");
-    unsigned TypeRecsSize = TypeRecs.size();
-    // Populate return type and parameter type names
-    for (unsigned i = 0; i < TypeRecsSize; i++) {
-      auto TR = TypeRecs[i];
-      OpTypes.emplace_back(getValueType(TR->getValueAsDef("VT")));
-      // Get the overload parameter index.
-      // TODO : Seems hacky. Is it possible that more than one parameter can
-      // be of overload kind??
-      // TODO: Check for any additional constraints specified for DXIL operation
-      // restricting return type.
-      if (i > 0) {
-        auto &CurParam = OpTypes.back();
-        if (getParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
-          OverloadParamIndex = i;
-        }
-      }
-    }
-    // Get the operation class
-    OpClass = R->getValueAsDef("OpClass")->getName();
-
-    // NOTE: For now, assume that attributes of DXIL Operation are the same as
+    // TODO: For now, assume that attributes of DXIL Operation are the same as
     // that of the intrinsic. Deviations are expected to be encoded in TableGen
     // record specification and handled accordingly here. Support to be added
-    // later.
+    // as needed.
     auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
     auto IntrPropListSize = IntrPropList->size();
     for (unsigned i = 0; i < IntrPropListSize; i++) {
@@ -191,12 +240,13 @@ static std::string getParameterKindStr(ParameterKind Kind) {
 }
 
 /// Return a string representation of OverloadKind enum that maps to
-/// input Simple Value Type enum
-/// \param VT Simple Value Type enum
+/// input LLVMType record
+/// \param R TableGen def record of class LLVMType
 /// \return std::string string representation of OverloadKind
 
-static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
-  switch (VT) {
+static std::string getOverloadKindStr(const Record *R) {
+  auto VTRec = R->getValueAsDef("VT");
+  switch (getValueType(VTRec)) {
   case MVT::isVoid:
     return "OverloadKind::VOID";
   case MVT::f16:
@@ -219,6 +269,20 @@ static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
     return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
   case MVT::fAny:
     return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
+  case MVT::Other:
+    // DXIL-specific overload types (DXILType records in DXIL.td)
+    {
+      std::string RetStr =
+          StringSwitch<std::string>(R->getName())
+              .Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
+              .Case("llvm_halforfloat_ty",
+                    "OverloadKind::HALF | OverloadKind::FLOAT")
+              .Default("");
+      if (!RetStr.empty()) {
+        return RetStr;
+      }
+    }
+    [[fallthrough]];
   default:
     llvm_unreachable(
         "Support for specified parameter OverloadKind not yet implemented");



More information about the llvm-commits mailing list