[llvm] [DirectX][NFC] Model precise overload type specification of DXIL Ops (PR #83917)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 12 11:45:50 PDT 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/83917

>From 9c49919b3a660153dd147a60a414710668cb6754 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Sat, 2 Mar 2024 21:19:53 -0500
Subject: [PATCH 1/6] [DirectX][NFC] Model precise overload type specification
 of DXIL Ops

Implements an abstraction to specify the precise overload types supported
by DXIL ops. These overload types are typically a subset of those supported
by the corresponding LLVM intrinsics.

Implements the corresponding changes in DXILEmitter backend.

Adds tests to verify that the expected error is reported for unsupported
overload types at code generation time, and that the error output is
correct.
---
 llvm/lib/Target/DirectX/DXIL.td           |  49 ++++++-
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp |   2 +-
 llvm/test/CodeGen/DirectX/frac.ll         |   3 -
 llvm/test/CodeGen/DirectX/frac_error.ll   |  14 ++
 llvm/test/CodeGen/DirectX/round_error.ll  |  13 ++
 llvm/test/CodeGen/DirectX/sin.ll          |  19 +--
 llvm/test/CodeGen/DirectX/sin_error.ll    |  14 ++
 llvm/utils/TableGen/DXILEmitter.cpp       | 150 +++++++++++++++-------
 8 files changed, 195 insertions(+), 69 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/frac_error.ll
 create mode 100644 llvm/test/CodeGen/DirectX/round_error.ll
 create mode 100644 llvm/test/CodeGen/DirectX/sin_error.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9536a01e125bb3..24937d3d32755c 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -205,28 +205,67 @@ defset list<DXILOpClass> OpClasses = {
   def writeSamplerFeedbackBias : DXILOpClass;
   def writeSamplerFeedbackGrad : DXILOpClass;
   def writeSamplerFeedbackLevel: DXILOpClass;
+
+  // This is a sentinel definition. Hence, it is placed at the end of the list
+  // rather than among the alphabetically sorted valid definitions above.
+  // Additionally, unlike all the others, it is capitalized.
+  def UnknownOpClass: DXILOpClass;
+}
+
+// Several of the overloaded DXIL Operations support only a subset of the
+// data types supported by the overloaded LLVM intrinsics they map to.
+// For example, the llvm.sin.* intrinsic operates on any floating-point type
+// and is lowered to the DXIL Sin operation. However, the only valid
+// overloads of the DXIL Sin operation are half (f16) and float (f32).
+//
+// The following abstracts overload types specific to DXIL operations.
+
+class DXILType : LLVMType<OtherVT> {
+  let isAny = 1;
 }
 
+// Concrete records for various overload types supported specifically by
+// DXIL Operations.
+
+def llvm_i16ori32_ty : DXILType;
+def llvm_halforfloat_ty : DXILType;
+
 // Abstraction DXIL Operation to LLVM intrinsic
-class DXILOpMapping<int opCode, DXILOpClass opClass, Intrinsic intrinsic, string doc> {
+class DXILOpMappingBase {
+  int OpCode = 0;                      // Opcode of DXIL Operation
+  DXILOpClass OpClass = UnknownOpClass;// Class of DXIL Operation.
+  Intrinsic LLVMIntrinsic = ?;         // LLVM Intrinsic DXIL Operation maps to
+  string Doc = "";                     // A short description of the operation
+  list<LLVMType> OpTypes = ?;          // Valid types of DXIL Operation in the
+                                       // format [returnTy, param1ty, ...]
+}
+
+class DXILOpMapping<int opCode, DXILOpClass opClass,
+                    Intrinsic intrinsic, string doc,
+                    list<LLVMType> opTys = []> : DXILOpMappingBase {
   int OpCode = opCode;                 // Opcode corresponding to DXIL Operation
-  DXILOpClass OpClass = opClass;             // Class of DXIL Operation.
+  DXILOpClass OpClass = opClass;       // Class of DXIL Operation.
   Intrinsic LLVMIntrinsic = intrinsic; // LLVM Intrinsic the DXIL Operation maps
   string Doc = doc;                    // to a short description of the operation
+  list<LLVMType> OpTypes = !if(!eq(!size(opTys), 0), LLVMIntrinsic.Types, opTys);
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
 def Sin  : DXILOpMapping<13, unary, int_sin,
-                         "Returns sine(theta) for theta in radians.">;
+                         "Returns sine(theta) for theta in radians.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Exp2 : DXILOpMapping<21, unary, int_exp2,
                          "Returns the base 2 exponential, or 2**x, of the specified value."
                          "exp2(x) = 2**x.">;
+                         "Returns sine(theta) for theta in radians.">;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
-                         "decimal part of the input.">;
+                         "decimal part of the input.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Round : DXILOpMapping<26, unary, int_round,
                          "Returns the input rounded to the nearest integer"
-                         "within a floating-point type.">;
+                         "within a floating-point type.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def UMax : DXILOpMapping<39, binary, int_umax,
                          "Unsigned integer maximum. UMax(a,b) = a > b ? a : b">;
 def FMad : DXILOpMapping<46, tertiary, int_fmuladd,
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 21a20d45b922d9..99bc5c214db298 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
   // FIXME: find the issue and report error in clang instead of check it in
   // backend.
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
-    llvm_unreachable("invalid overload");
+    report_fatal_error("Invalid Overload Type", false);
   }
 
   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
diff --git a/llvm/test/CodeGen/DirectX/frac.ll b/llvm/test/CodeGen/DirectX/frac.ll
index ab605ed6084aa4..ae86fe06654da1 100644
--- a/llvm/test/CodeGen/DirectX/frac.ll
+++ b/llvm/test/CodeGen/DirectX/frac.ll
@@ -29,6 +29,3 @@ entry:
   %dx.frac = call half @llvm.dx.frac.f16(half %0)
   ret half %dx.frac
 }
-
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare half @llvm.dx.frac.f16(half) #1
diff --git a/llvm/test/CodeGen/DirectX/frac_error.ll b/llvm/test/CodeGen/DirectX/frac_error.ll
new file mode 100644
index 00000000000000..ad51bb0b6dee71
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/frac_error.ll
@@ -0,0 +1,14 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+; Function Attrs: noinline nounwind optnone
+define noundef double @frac_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %dx.frac = call double @llvm.dx.frac.f64(double %0)
+  ret double %dx.frac
+}
diff --git a/llvm/test/CodeGen/DirectX/round_error.ll b/llvm/test/CodeGen/DirectX/round_error.ll
new file mode 100644
index 00000000000000..3bd87b2bbf0200
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/round_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef double @round_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %elt.round = call double @llvm.round.f64(double %0)
+  ret double %elt.round
+}
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index bb31d28bfcfee6..ea6f4d6200ee49 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -4,11 +4,8 @@
 ; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
 ; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
 
-target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
-target triple = "dxil-pc-shadermodel6.7-library"
-
 ; Function Attrs: noinline nounwind optnone
-define noundef float @_Z3foof(float noundef %a) #0 {
+define noundef float @sin_float(float noundef %a) #0 {
 entry:
   %a.addr = alloca float, align 4
   store float %a, ptr %a.addr, align 4
@@ -21,7 +18,7 @@ entry:
 declare float @llvm.sin.f32(float) #1
 
 ; Function Attrs: noinline nounwind optnone
-define noundef half @_Z3barDh(half noundef %a) #0 {
+define noundef half @sin_half(half noundef %a) #0 {
 entry:
   %a.addr = alloca half, align 2
   store half %a, ptr %a.addr, align 2
@@ -29,15 +26,3 @@ entry:
   %1 = call half @llvm.sin.f16(half %0)
   ret half %1
 }
-
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare half @llvm.sin.f16(half) #1
-
-attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
-attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
-
-!llvm.module.flags = !{!0}
-!llvm.ident = !{!1}
-
-!0 = !{i32 1, !"wchar_size", i32 4}
-!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
new file mode 100644
index 00000000000000..2a839faff00fe7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -0,0 +1,14 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; This test is expected to fail with the following error
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @sin_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %1 = call double @llvm.sin.f64(double %0)
+  ret double %1
+}
+
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index fc958f5328736c..ac355464fe6ca5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Support/DXILABI.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
+#include <string>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -38,8 +39,8 @@ struct DXILOperationDesc {
   int OpCode;         // ID of DXIL operation
   StringRef OpClass;  // name of the opcode class
   StringRef Doc;      // the documentation description of this instruction
-  SmallVector<MVT::SimpleValueType> OpTypes; // Vector of operand types -
-                                             // return type is at index 0
+  SmallVector<Record *> OpTypes; // Vector of operand type records -
+                                 // return type is at index 0
   SmallVector<std::string>
       OpAttributes;     // operation attribute represented as strings
   StringRef Intrinsic;  // The llvm intrinsic map to OpName. Default is "" which
@@ -57,20 +58,21 @@ struct DXILOperationDesc {
   DXILShaderModel ShaderModel;           // minimum shader model required
   DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                          // translation by linker
-  int OverloadParamIndex; // parameter index which control the overload.
-                          // When < 0, should be only 1 overload type.
+  int OverloadParamIndex;             // Index of parameter with overload type.
+                                      //   -1 : no overload types
   SmallVector<StringRef, 4> counters; // counters for this inst.
   DXILOperationDesc(const Record *);
 };
 } // end anonymous namespace
 
-/// Convert DXIL type name string to dxil::ParameterKind
+/// Return dxil::ParameterKind corresponding to input LLVMType record
 ///
-/// \param VT Simple Value Type
+/// \param R TableGen def record of class LLVMType
 /// \return ParameterKind As defined in llvm/Support/DXILABI.h
 
-static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
-  switch (VT) {
+static ParameterKind getParameterKind(const Record *R) {
+  auto VTRec = R->getValueAsDef("VT");
+  switch (getValueType(VTRec)) {
   case MVT::isVoid:
     return ParameterKind::VOID;
   case MVT::f16:
@@ -90,6 +92,18 @@ static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
   case MVT::fAny:
   case MVT::iAny:
     return ParameterKind::OVERLOAD;
+  case MVT::Other:
+    // Handle DXIL-specific overload types
+    {
+      auto RetKind = StringSwitch<ParameterKind>(R->getNameInitAsString())
+                         .Cases("llvm_i16ori32_ty", "llvm_halforfloat_ty",
+                                ParameterKind::OVERLOAD)
+                         .Default(ParameterKind::INVALID);
+      if (RetKind != ParameterKind::INVALID) {
+        return RetKind;
+      }
+    }
+    LLVM_FALLTHROUGH;
   default:
     llvm_unreachable("Support for specified DXIL Type not yet implemented");
   }
@@ -106,45 +120,80 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
 
   Doc = R->getValueAsString("Doc");
 
+  auto TypeRecs = R->getValueAsListOfDefs("OpTypes");
+  unsigned TypeRecsSize = TypeRecs.size();
+  // Populate OpTypes with return type and parameter types
+
+  // Parameter indices of overloaded parameters.
+  // This vector contains overload parameters in the order used to
+  // resolve an LLVMMatchType, in accordance with the convention outlined in
+  // the comment before the definition of class LLVMMatchType in
+  // llvm/IR/Intrinsics.td.
+  SmallVector<int> OverloadParamIndices;
+  for (unsigned i = 0; i < TypeRecsSize; i++) {
+    auto TR = TypeRecs[i];
+    // Track operation parameter indices of any overload types
+    auto isAny = TR->getValueAsInt("isAny");
+    if (isAny == 1) {
+      // TODO: At present it is expected that all overload types in a DXIL Op
+      // are of the same type. Hence, OverloadParamIndices will have only one
+      // element. This implies we do not need a vector. However, until more
+      // (all?) DXIL Ops are added in DXIL.td, a vector is being used to flag
+      // cases where this assumption does not hold.
+      if (!OverloadParamIndices.empty()) {
+        bool knownType = true;
+        // Ensure that the same overload type registered earlier is being used
+        for (auto Idx : OverloadParamIndices) {
+          if (TR != TypeRecs[Idx]) {
+            knownType = false;
+            break;
+          }
+        }
+        if (!knownType) {
+          report_fatal_error("Specification of multiple differing overload "
+                             "parameter types not yet supported",
+                             false);
+        }
+      } else {
+        OverloadParamIndices.push_back(i);
+      }
+    }
+    // Populate OpTypes array according to the type specification
+    if (TR->isAnonymous()) {
+      // Check prior overload types exist
+      assert(!OverloadParamIndices.empty() &&
+             "No prior overloaded parameter found to match.");
+      // Get the index that the anonymous LLVMMatchType record TR refers to
+      auto OLParamIndex = TR->getValueAsInt("Number");
+      // Resolve TR to the type record at that index and record it in OpTypes
+      OpTypes.emplace_back(TypeRecs[OLParamIndex]);
+    } else {
+      // A non-anonymous type. Just record it in OpTypes
+      OpTypes.emplace_back(TR);
+    }
+  }
+
+  // Set the index of the overload parameter, if any.
+  OverloadParamIndex = -1; // default; indicating none
+  if (!OverloadParamIndices.empty()) {
+    if (OverloadParamIndices.size() > 1)
+      report_fatal_error("Multiple overload type specification not supported",
+                         false);
+    OverloadParamIndex = OverloadParamIndices[0];
+  }
+  // Get the operation class
+  OpClass = R->getValueAsDef("OpClass")->getName();
+
   if (R->getValue("LLVMIntrinsic")) {
     auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
     auto DefName = IntrinsicDef->getName();
     assert(DefName.starts_with("int_") && "invalid intrinsic name");
     // Remove the int_ from intrinsic name.
     Intrinsic = DefName.substr(4);
-    // TODO: It is expected that return type and parameter types of
-    // DXIL Operation are the same as that of the intrinsic. Deviations
-    // are expected to be encoded in TableGen record specification and
-    // handled accordingly here. Support to be added later, as needed.
-    // Get parameter type list of the intrinsic. Types attribute contains
-    // the list of as [returnType, param1Type,, param2Type, ...]
-
-    OverloadParamIndex = -1;
-    auto TypeRecs = IntrinsicDef->getValueAsListOfDefs("Types");
-    unsigned TypeRecsSize = TypeRecs.size();
-    // Populate return type and parameter type names
-    for (unsigned i = 0; i < TypeRecsSize; i++) {
-      auto TR = TypeRecs[i];
-      OpTypes.emplace_back(getValueType(TR->getValueAsDef("VT")));
-      // Get the overload parameter index.
-      // TODO : Seems hacky. Is it possible that more than one parameter can
-      // be of overload kind??
-      // TODO: Check for any additional constraints specified for DXIL operation
-      // restricting return type.
-      if (i > 0) {
-        auto &CurParam = OpTypes.back();
-        if (getParameterKind(CurParam) >= ParameterKind::OVERLOAD) {
-          OverloadParamIndex = i;
-        }
-      }
-    }
-    // Get the operation class
-    OpClass = R->getValueAsDef("OpClass")->getName();
-
-    // NOTE: For now, assume that attributes of DXIL Operation are the same as
+    // TODO: For now, assume that attributes of DXIL Operation are the same as
     // that of the intrinsic. Deviations are expected to be encoded in TableGen
     // record specification and handled accordingly here. Support to be added
-    // later.
+    // as needed.
     auto IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
     auto IntrPropListSize = IntrPropList->size();
     for (unsigned i = 0; i < IntrPropListSize; i++) {
@@ -191,12 +240,13 @@ static std::string getParameterKindStr(ParameterKind Kind) {
 }
 
 /// Return a string representation of OverloadKind enum that maps to
-/// input Simple Value Type enum
-/// \param VT Simple Value Type enum
+/// input LLVMType record
+/// \param R TableGen def record of class LLVMType
 /// \return std::string string representation of OverloadKind
 
-static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
-  switch (VT) {
+static std::string getOverloadKindStr(const Record *R) {
+  auto VTRec = R->getValueAsDef("VT");
+  switch (getValueType(VTRec)) {
   case MVT::isVoid:
     return "OverloadKind::VOID";
   case MVT::f16:
@@ -219,6 +269,20 @@ static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
     return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
   case MVT::fAny:
     return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";
+  case MVT::Other:
+    // Handle DXIL-specific overload types
+    {
+      auto RetStr =
+          StringSwitch<std::string>(R->getNameInitAsString())
+              .Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
+              .Case("llvm_halforfloat_ty",
+                    "OverloadKind::HALF | OverloadKind::FLOAT")
+              .Default("");
+      if (RetStr != "") {
+        return RetStr;
+      }
+    }
+    LLVM_FALLTHROUGH;
   default:
     llvm_unreachable(
         "Support for specified parameter OverloadKind not yet implemented");

>From 47078fce2afa713f14a9db52a2462792bc29bfce Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 5 Mar 2024 13:28:52 -0500
Subject: [PATCH 2/6] Delete spurious code - per PR review feedback

---
 llvm/test/CodeGen/DirectX/sin.ll | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
index ea6f4d6200ee49..1f285c433581cf 100644
--- a/llvm/test/CodeGen/DirectX/sin.ll
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -14,9 +14,6 @@ entry:
   ret float %1
 }
 
-; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
-declare float @llvm.sin.f32(float) #1
-
 ; Function Attrs: noinline nounwind optnone
 define noundef half @sin_half(half noundef %a) #0 {
 entry:

>From 9a231f508ddd96a18ad5abfcb5ce68c9ecd85c6b Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 7 Mar 2024 14:05:09 -0500
Subject: [PATCH 3/6] Address PR review comments.  - Improve comments.  - Add
 fields to DXILType to distinguish types being represented.    This allows for
 cleaner identification of the type.

---
 llvm/lib/Target/DirectX/DXIL.td           |  9 ++++++---
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp |  2 +-
 llvm/test/CodeGen/DirectX/frac_error.ll   |  2 +-
 llvm/test/CodeGen/DirectX/sin_error.ll    |  2 +-
 llvm/utils/TableGen/DXILEmitter.cpp       | 12 ++++--------
 5 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 24937d3d32755c..ee73f7514382d4 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -222,13 +222,17 @@ defset list<DXILOpClass> OpClasses = {
 
 class DXILType : LLVMType<OtherVT> {
   let isAny = 1;
+  int isI16OrI32 = 0;
+  int isHalfOrFloat = 0;
 }
 
 // Concrete records for various overload types supported specifically by
 // DXIL Operations.
+let isI16OrI32 = 1 in
+  def llvm_i16ori32_ty : DXILType;
 
-def llvm_i16ori32_ty : DXILType;
-def llvm_halforfloat_ty : DXILType;
+let isHalfOrFloat = 1 in
+  def llvm_halforfloat_ty : DXILType;
 
 // Abstraction DXIL Operation to LLVM intrinsic
 class DXILOpMappingBase {
@@ -257,7 +261,6 @@ def Sin  : DXILOpMapping<13, unary, int_sin,
 def Exp2 : DXILOpMapping<21, unary, int_exp2,
                          "Returns the base 2 exponential, or 2**x, of the specified value."
                          "exp2(x) = 2**x.">;
-                         "Returns sine(theta) for theta in radians.">;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 99bc5c214db298..e7aeda051fab48 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -257,7 +257,7 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
   // FIXME: find the issue and report error in clang instead of check it in
   // backend.
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
-    report_fatal_error("Invalid Overload Type", false);
+    report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }
 
   std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
diff --git a/llvm/test/CodeGen/DirectX/frac_error.ll b/llvm/test/CodeGen/DirectX/frac_error.ll
index ad51bb0b6dee71..ebce76105ad4d7 100644
--- a/llvm/test/CodeGen/DirectX/frac_error.ll
+++ b/llvm/test/CodeGen/DirectX/frac_error.ll
@@ -1,6 +1,6 @@
 ; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
 
-; This test is expected to fail with the following error
+; DXIL operation frac does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload Type
 
 ; Function Attrs: noinline nounwind optnone
diff --git a/llvm/test/CodeGen/DirectX/sin_error.ll b/llvm/test/CodeGen/DirectX/sin_error.ll
index 2a839faff00fe7..ece0e530315b2f 100644
--- a/llvm/test/CodeGen/DirectX/sin_error.ll
+++ b/llvm/test/CodeGen/DirectX/sin_error.ll
@@ -1,6 +1,6 @@
 ; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
 
-; This test is expected to fail with the following error
+; DXIL operation sin does not support double overload type
 ; CHECK: LLVM ERROR: Invalid Overload
 
 define noundef double @sin_double(double noundef %a) #0 {
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index ac355464fe6ca5..2f4363673320be 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -272,14 +272,10 @@ static std::string getOverloadKindStr(const Record *R) {
   case MVT::Other:
     // Handle DXIL-specific overload types
     {
-      auto RetStr =
-          StringSwitch<std::string>(R->getNameInitAsString())
-              .Case("llvm_i16ori32_ty", "OverloadKind::I16 | OverloadKind::I32")
-              .Case("llvm_halforfloat_ty",
-                    "OverloadKind::HALF | OverloadKind::FLOAT")
-              .Default("");
-      if (RetStr != "") {
-        return RetStr;
+      if (R->getValueAsInt("isHalfOrFloat")) {
+        return "OverloadKind::HALF | OverloadKind::FLOAT";
+      } else if (R->getValueAsInt("isI16OrI32")) {
+        return "OverloadKind::I16 | OverloadKind::I32";
       }
     }
     LLVM_FALLTHROUGH;

>From 27563e62233c3133e9083d381c5e90e0d58b81ee Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Mon, 11 Mar 2024 12:30:20 -0400
Subject: [PATCH 4/6] [DirectX] Specify overload type of DXIL Op exp2 precisely.
 Add a test to verify error generation for exp2 with overload type double.

Use type property instead of record name string to identify precise type.
---
 llvm/lib/Target/DirectX/DXIL.td         |  3 ++-
 llvm/test/CodeGen/DirectX/exp2_error.ll | 13 +++++++++++++
 llvm/utils/TableGen/DXILEmitter.cpp     |  9 +++------
 3 files changed, 18 insertions(+), 7 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/exp2_error.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index ee73f7514382d4..66b0ef24332c25 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -260,7 +260,8 @@ def Sin  : DXILOpMapping<13, unary, int_sin,
                          [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Exp2 : DXILOpMapping<21, unary, int_exp2,
                          "Returns the base 2 exponential, or 2**x, of the specified value."
-                         "exp2(x) = 2**x.">;
+                         "exp2(x) = 2**x.",
+                         [llvm_halforfloat_ty, LLVMMatchType<0>]>;
 def Frac : DXILOpMapping<22, unary, int_dx_frac,
                          "Returns a fraction from 0 to 1 that represents the "
                          "decimal part of the input.",
diff --git a/llvm/test/CodeGen/DirectX/exp2_error.ll b/llvm/test/CodeGen/DirectX/exp2_error.ll
new file mode 100644
index 00000000000000..6b9126785fd4b8
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/exp2_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation exp2 does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload
+
+define noundef double @exp2_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %elt.exp2 = call double @llvm.exp2.f64(double %0)
+  ret double %elt.exp2
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 2f4363673320be..7b964b7a4fccf5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -95,12 +95,9 @@ static ParameterKind getParameterKind(const Record *R) {
   case MVT::Other:
     // Handle DXIL-specific overload types
     {
-      auto RetKind = StringSwitch<ParameterKind>(R->getNameInitAsString())
-                         .Cases("llvm_i16ori32_ty", "llvm_halforfloat_ty",
-                                ParameterKind::OVERLOAD)
-                         .Default(ParameterKind::INVALID);
-      if (RetKind != ParameterKind::INVALID) {
-        return RetKind;
+      if ((R->getValueAsInt("isHalfOrFloat")) ||
+          (R->getValueAsInt("isI16OrI32"))) {
+        return ParameterKind::OVERLOAD;
       }
     }
     LLVM_FALLTHROUGH;

>From 006b0ab355a1bbb2bffaea3d889f12d9192cba0e Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 12 Mar 2024 12:11:39 -0400
Subject: [PATCH 5/6] Delete outdated comment

---
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index e7aeda051fab48..11b24d04492368 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -254,8 +254,6 @@ static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
   const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
 
   OverloadKind Kind = getOverloadKind(OverloadTy);
-  // FIXME: find the issue and report error in clang instead of check it in
-  // backend.
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }

>From 7485a4c894d3cde85c6775b22e2e1ac4c88f168f Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 12 Mar 2024 14:44:56 -0400
Subject: [PATCH 6/6] Address feedback on extra scope and parentheses.

---
 llvm/utils/TableGen/DXILEmitter.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 7b964b7a4fccf5..59089929837ebb 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -94,11 +94,8 @@ static ParameterKind getParameterKind(const Record *R) {
     return ParameterKind::OVERLOAD;
   case MVT::Other:
     // Handle DXIL-specific overload types
-    {
-      if ((R->getValueAsInt("isHalfOrFloat")) ||
-          (R->getValueAsInt("isI16OrI32"))) {
-        return ParameterKind::OVERLOAD;
-      }
+    if (R->getValueAsInt("isHalfOrFloat") || R->getValueAsInt("isI16OrI32")) {
+      return ParameterKind::OVERLOAD;
     }
     LLVM_FALLTHROUGH;
   default:



More information about the llvm-commits mailing list