[llvm] [DirectX][DXIL] Distinguish return type for overload type resolution. (PR #85646)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 10:44:50 PDT 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/85646

>From 3017e429843000f38e367fce5966030289d28517 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 15 Mar 2024 18:36:21 -0400
Subject: [PATCH 1/2] [DirectX][DXIL] Distinguish return type for overload type
 resolution. Return type of DXIL Ops may be different from valid overload type
 of the parameters, if any. Such DXIL Ops are correctly represented in
 DXIL.td. However, DXILEmitter assumes the return type to be the same as
 parameter overload type, if one exists. This results in generation of an
 incorrect overload index value in DXILOperation.inc for the DXIL Op and an
 incorrect DXIL operation function call in DXILOpLowering pass.

This change distinguishes return types correctly from parameter overload
types in DXILEmitter backend to handle such DXIL ops.

Add specification for DXIL Op isinf and corresponding tests to verify
the above change.

Fixes issue #85125
---
 llvm/lib/Target/DirectX/DXIL.td            |  3 ++
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp  | 49 ++++++++++++----------
 llvm/lib/Target/DirectX/DXILOpBuilder.h    |  3 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp |  7 +++-
 llvm/test/CodeGen/DirectX/isinf.ll         | 25 +++++++++++
 llvm/test/CodeGen/DirectX/isinf_error.ll   | 13 ++++++
 llvm/utils/TableGen/DXILEmitter.cpp        |  7 +++-
 7 files changed, 81 insertions(+), 26 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/isinf.ll
 create mode 100644 llvm/test/CodeGen/DirectX/isinf_error.ll

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 216fa5b10c8f4d..36eb29d53766f0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -255,6 +255,9 @@ class DXILOpMapping<int opCode, DXILOpClass opClass,
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
+def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
+                         "Determines if the specified value is infinite.",
+                         [llvm_i1_ty, llvm_halforfloat_ty]>;
 def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.",
                          [llvm_halforfloat_ty, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 11b24d04492368..0c05cf6ecdf195 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,13 +229,13 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
 ///               its specification in DXIL.td.
 /// \param OverloadTy Return type to be used to construct DXIL function type.
 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
-                                           Type *OverloadTy) {
+                                           Type *ReturnTy, Type *OverloadTy) {
   SmallVector<Type *> ArgTys;
 
   auto ParamKinds = getOpCodeParameterKind(*Prop);
 
-  // Add OverloadTy as return type of the function
-  ArgTys.emplace_back(OverloadTy);
+  // Add ReturnTy as return type of the function
+  ArgTys.emplace_back(ReturnTy);
 
   // Add DXIL Opcode value type viz., Int32 as first argument
   ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
@@ -249,34 +249,39 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
 }
 
-static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
-                                                Type *OverloadTy, Module &M) {
-  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+namespace llvm {
+namespace dxil {
+
+// Create an instruction that calls DXIL Op with return type, specified opcode,
+// and call arguments.
+// \param OpCode Opcode of the DXIL Op call constructed
+// \param ReturnTy Return type of the DXIL Op call constructed
+// \param OverloadTy Overload type of the DXIL Op call constructed
+// \return DXIL Op call constructed
+CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
+                                          Type *OverloadTy,
+                                          llvm::iterator_range<Use *> Args) {
+  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
 
   OverloadKind Kind = getOverloadKind(OverloadTy);
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }
 
-  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
-  // Dependent on name to dedup.
-  if (auto *Fn = M.getFunction(FnName))
-    return FunctionCallee(Fn);
-
-  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
-  return M.getOrInsertFunction(FnName, DXILOpFT);
-}
-
-namespace llvm {
-namespace dxil {
-
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
-                                          llvm::iterator_range<Use *> Args) {
-  auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
+  std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
+  FunctionCallee DXILFn;
+  // Get the function with name DXILFnName, if one exists
+  if (auto *Func = M.getFunction(DXILFnName)) {
+    DXILFn = FunctionCallee(Func);
+  } else {
+    // Construct and add a function with name DXILFnName
+    FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
+    DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
+  }
   SmallVector<Value *> FullArgs;
   FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
   FullArgs.append(Args.begin(), Args.end());
-  return B.CreateCall(Fn, FullArgs);
+  return B.CreateCall(DXILFn, FullArgs);
 }
 
 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 1c15f109184adf..89e91e2a0784db 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -29,7 +29,8 @@ namespace dxil {
 class DXILOpBuilder {
 public:
   DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
-  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
+  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
+                             Type *OverloadTy,
                              llvm::iterator_range<Use *> Args);
   Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index e5c2042e7d16ae..9af94834565457 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -44,7 +44,12 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     Args.emplace_back(DXILOpArg);
     Args.append(CI->arg_begin(), CI->arg_end());
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+    // Return type of F may be different from that of its arguments.
+    // It may not correspond to any overload type (if one exists) of the
+    // parameters. Pass the return type and overload type separately to handle
+    // such situations.
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
+                                              OverloadTy, CI->args());
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/isinf.ll b/llvm/test/CodeGen/DirectX/isinf.ll
new file mode 100644
index 00000000000000..e2975da90bfc1b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/isinf.ll
@@ -0,0 +1,25 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for isinf are generated for float and half.
+; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
+; CHECK: call i1 @dx.op.isSpecialFloat.f16(i32 9, half %{{.*}})
+
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @isinf_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %dx.isinf = call i1 @llvm.dx.isinf.f32(float %0)
+  ret i1 %dx.isinf
+}
+
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @isinf_half(half noundef %p0) #0 {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.isinf = call i1 @llvm.dx.isinf.f16(half %0)
+  ret i1 %dx.isinf
+}
diff --git a/llvm/test/CodeGen/DirectX/isinf_error.ll b/llvm/test/CodeGen/DirectX/isinf_error.ll
new file mode 100644
index 00000000000000..95b2d0cabcc43b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/isinf_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation isinf does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef i1 @isinf_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %dx.isinf = call i1 @llvm.dx.isinf.f64(double %0)
+  ret i1 %dx.isinf
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 59089929837ebb..cdfc8d3d174487 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -119,7 +119,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
-  // This vector contains overload parameters in the order order used to
+  // This vector contains overload parameters in the order used to
   // resolve an LLVMMatchType in accordance with  convention outlined in
   // the comment before the definition of class LLVMMatchType in
   // llvm/IR/Intrinsics.td
@@ -398,10 +398,13 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
   for (auto &Op : Ops) {
+    // If no overload parameter exists, treat the return type as overload
+    // parameter to emit the appropriate overload kind enum.
+    auto OLParamIdx = (Op.OverloadParamIndex < 0) ? 0 : Op.OverloadParamIndex;
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
-       << getOverloadKindStr(Op.OpTypes[0]) << ", "
+       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";

>From e41a3eae7cde798dfb24c490e800be0ec2312417 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Mon, 18 Mar 2024 13:20:07 -0400
Subject: [PATCH 2/2] Incorporate PR feedback. First parameter of the DXIL Op
 is the overload type.

---
 llvm/utils/TableGen/DXILEmitter.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index cdfc8d3d174487..af1efb8aa99f73 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -398,9 +398,16 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
   for (auto &Op : Ops) {
-    // If no overload parameter exists, treat the return type as overload
-    // parameter to emit the appropriate overload kind enum.
-    auto OLParamIdx = (Op.OverloadParamIndex < 0) ? 0 : Op.OverloadParamIndex;
+    // Consider Op.OverloadParamIndex as the overload parameter index, by
+    // default
+    auto OLParamIdx = Op.OverloadParamIndex;
+    // If no overload parameter index is set, treat the first parameter type
+    // as the overload type; if the Op has no parameters, treat the return
+    // type as the overload type instead. This selects the appropriate
+    // overload kind enum to emit.
+    if (OLParamIdx < 0) {
+      OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
+    }
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "



More information about the llvm-commits mailing list