[llvm-branch-commits] [llvm] [DirectX] Make DXILOpBuilder's API more useable (PR #101250)

Justin Bogner via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jul 30 23:15:57 PDT 2024


https://github.com/bogner updated https://github.com/llvm/llvm-project/pull/101250

>From bb8ed818abc30065713e9a21863624ad9b6c4820 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Tue, 30 Jul 2024 15:56:35 -0700
Subject: [PATCH 1/2] Improve comment

Created using spr 1.3.5-bogner
---
 llvm/utils/TableGen/DXILEmitter.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 0c2b9ec69d82e..66b6206fedea5 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -151,9 +151,10 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
         assert(knownType && "Specification of multiple differing overload "
                             "parameter types not yet supported");
       } else {
-        // Skip the return value - nothing is overloaded on only return, and it
-        // makes it harder to determine the overload from an argument list
-        // later.
+        // Skip the return value - having the overload index point at the return
+        // value makes it hard to determine the overload from an argument list,
+        // and we treat unoverloaded functions and those overloaded on return
+        // identically anyway.
         if (i != 0)
           OverloadParamIndices.push_back(i);
       }

>From 1b50eb18604db7396e96f23e4b36f0598fdb2592 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Tue, 30 Jul 2024 19:14:23 -0700
Subject: [PATCH 2/2] Handle overloads on return types and clean up doxygen
 comment

Created using spr 1.3.5-bogner
---
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp  | 18 ++++----
 llvm/lib/Target/DirectX/DXILOpBuilder.h    | 15 +++---
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |  3 +-
 llvm/utils/TableGen/DXILEmitter.cpp        | 53 +++++-----------------
 4 files changed, 33 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index d43ac1119ff48..42df7c90cb337 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -285,9 +285,6 @@ static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) {
 /// the following prototype
 ///     OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
 /// <param-types> are constructed from types in Prop.
-/// \param Prop  Structure containing DXIL Operation properties based on
-///               its specification in DXIL.td.
-/// \param OverloadTy Return type to be used to construct DXIL function type.
 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
                                            LLVMContext &Context,
                                            Type *OverloadTy) {
@@ -355,11 +352,16 @@ static Error makeOpError(dxil::OpCode OpCode, Twine Msg) {
 }
 
 Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
-                                                ArrayRef<Value *> Args) {
+                                                ArrayRef<Value *> Args,
+                                                Type *RetTy) {
   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
 
   Type *OverloadTy = nullptr;
-  if (Prop->OverloadParamIndex > 0) {
+  if (Prop->OverloadParamIndex == 0) {
+    if (!RetTy)
+      return makeOpError(OpCode, "Op overloaded on unknown return type");
+    OverloadTy = RetTy;
+  } else if (Prop->OverloadParamIndex > 0) {
     // The index counts including the return type
     unsigned ArgIndex = Prop->OverloadParamIndex - 1;
     if (static_cast<unsigned>(ArgIndex) >= Args.size())
@@ -422,9 +424,9 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
   return B.CreateCall(DXILFn, OpArgs);
 }
 
-CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode,
-                                  ArrayRef<Value *> &Args) {
-  Expected<CallInst *> Result = tryCreateOp(OpCode, Args);
+CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> &Args,
+                                  Type *RetTy) {
+  Expected<CallInst *> Result = tryCreateOp(OpCode, Args, RetTy);
   if (Error E = Result.takeError())
     llvm_unreachable("Invalid arguments for operation");
   return *Result;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index df55b69d77d53..ff66f39a3ceb3 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -33,20 +33,23 @@ class DXILOpBuilder {
 
   /// Create a call instruction for the given DXIL op. The arguments
   /// must be valid for an overload of the operation.
-  CallInst *createOp(dxil::OpCode Op, ArrayRef<Value *> &Args);
+  CallInst *createOp(dxil::OpCode Op, ArrayRef<Value *> &Args,
+                     Type *RetTy = nullptr);
 
 #define DXIL_OPCODE(Op, Name)                                                  \
-  CallInst *create##Name##Op(ArrayRef<Value *> &Args) {                        \
-    return createOp(dxil::OpCode(Op), Args);                                   \
+  CallInst *create##Name##Op(ArrayRef<Value *> &Args, Type *RetTy = nullptr) { \
+    return createOp(dxil::OpCode(Op), Args, RetTy);                            \
   }
 #include "DXILOperation.inc"
 
   /// Try to create a call instruction for the given DXIL op. Fails if the
   /// overload is invalid.
-  Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args);
+  Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
+                                   Type *RetTy = nullptr);
 #define DXIL_OPCODE(Op, Name)                                                  \
-  Expected<CallInst *> tryCreate##Name##Op(ArrayRef<Value *> &Args) {          \
-    return tryCreateOp(dxil::OpCode(Op), Args);                                \
+  Expected<CallInst *> tryCreate##Name##Op(ArrayRef<Value *> &Args,            \
+                                           Type *RetTy = nullptr) {            \
+    return tryCreateOp(dxil::OpCode(Op), Args, RetTy);                         \
   }
 #include "DXILOperation.inc"
 
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 4c27700e76f32..5f84cdcfda6de 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -89,7 +89,8 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     } else
       Args.append(CI->arg_begin(), CI->arg_end());
 
-    Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args);
+    Expected<CallInst *> OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args,
+                                                             F.getReturnType());
     if (Error E = OpCallOrErr.takeError()) {
       std::string Message(toString(std::move(E)));
       DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message,
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 66b6206fedea5..c02b6ea64020a 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -132,54 +132,25 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // resolve an LLVMMatchType in accordance with  convention outlined in
   // the comment before the definition of class LLVMMatchType in
   // llvm/IR/Intrinsics.td
-  SmallVector<int> OverloadParamIndices;
+  OverloadParamIndex = -1; // A sigil meaning none.
   for (unsigned i = 0; i < ParamTypeRecsSize; i++) {
     auto TR = ParamTypeRecs[i];
     // Track operation parameter indices of any overload types
     auto isAny = TR->getValueAsInt("isAny");
     if (isAny == 1) {
-      // All overload types in a DXIL Op are required to be of the same type.
-      if (!OverloadParamIndices.empty()) {
-        [[maybe_unused]] bool knownType = true;
-        // Ensure that the same overload type registered earlier is being used
-        for (auto Idx : OverloadParamIndices) {
-          if (TR != ParamTypeRecs[Idx]) {
-            knownType = false;
-            break;
-          }
-        }
-        assert(knownType && "Specification of multiple differing overload "
-                            "parameter types not yet supported");
-      } else {
-        // Skip the return value - having the overload index point at the return
-        // value makes it hard to determine the overload from an argument list,
-        // and we treat unoverloaded functions and those overloaded on return
-        // identically anyway.
-        if (i != 0)
-          OverloadParamIndices.push_back(i);
+      if (OverloadParamIndex != -1) {
+        assert(TR == ParamTypeRecs[OverloadParamIndex] &&
+               "Specification of multiple differing overload parameter types "
+               "is not supported");
       }
+      // Keep the earliest parameter index we see, but if it was the return type
+      // overwrite it with the first overloaded argument.
+      if (OverloadParamIndex <= 0)
+        OverloadParamIndex = i;
     }
-    // Populate OpTypes array according to the type specification
-    if (TR->isAnonymous()) {
-      // Check prior overload types exist
-      assert(!OverloadParamIndices.empty() &&
-             "No prior overloaded parameter found to match.");
-      // Get the parameter index of anonymous type, TR, references
-      auto OLParamIndex = TR->getValueAsInt("Number");
-      // Resolve and insert the type to that at OLParamIndex
-      OpTypes.emplace_back(ParamTypeRecs[OLParamIndex]);
-    } else {
-      // A non-anonymous type. Just record it in OpTypes
-      OpTypes.emplace_back(TR);
-    }
-  }
-
-  // Set the index of the overload parameter, if any.
-  OverloadParamIndex = -1; // default; indicating none
-  if (!OverloadParamIndices.empty()) {
-    assert(OverloadParamIndices.size() == 1 &&
-           "Multiple overload type specification not supported");
-    OverloadParamIndex = OverloadParamIndices[0];
+    if (TR->isAnonymous())
+      PrintFatalError(TR, "Only concrete types are allowed here");
+    OpTypes.emplace_back(TR);
   }
 
   // Get overload records



More information about the llvm-branch-commits mailing list