[llvm] [DirectX][DXIL] Distinguish return type for overload type resolution. (PR #85646)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 07:36:50 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

<details>
<summary>Changes</summary>

The return type of a DXIL Op may differ from the valid overload type of its parameters, if any. Such DXIL Ops are correctly represented in DXIL.td. However, DXILEmitter assumes the return type to be the same as the parameter overload type, if one exists. This results in the generation of an incorrect overload index value in DXILOperation.inc for the DXIL Op, and of an incorrect DXIL operation function call in the DXILOpLowering pass.

This change correctly distinguishes return types from parameter overload types in the DXILEmitter backend so that such DXIL Ops are handled properly.

Add a specification for the DXIL Op isinf and corresponding tests to verify the above change.

Fixes issue #85125

---
Full diff: https://github.com/llvm/llvm-project/pull/85646.diff


7 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXIL.td (+3) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+27-22) 
- (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2-1) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+6-1) 
- (added) llvm/test/CodeGen/DirectX/isinf.ll (+25) 
- (added) llvm/test/CodeGen/DirectX/isinf_error.ll (+13) 
- (modified) llvm/utils/TableGen/DXILEmitter.cpp (+5-2) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 216fa5b10c8f4d..36eb29d53766f0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -255,6 +255,9 @@ class DXILOpMapping<int opCode, DXILOpClass opClass,
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
+def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
+                         "Determines if the specified value is infinite.",
+                         [llvm_i1_ty, llvm_halforfloat_ty]>;
 def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.",
                          [llvm_halforfloat_ty, LLVMMatchType<0>]>;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 11b24d04492368..0c05cf6ecdf195 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,13 +229,13 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
 ///               its specification in DXIL.td.
 /// \param OverloadTy Return type to be used to construct DXIL function type.
 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
-                                           Type *OverloadTy) {
+                                           Type *ReturnTy, Type *OverloadTy) {
   SmallVector<Type *> ArgTys;
 
   auto ParamKinds = getOpCodeParameterKind(*Prop);
 
-  // Add OverloadTy as return type of the function
-  ArgTys.emplace_back(OverloadTy);
+  // Add ReturnTy as return type of the function
+  ArgTys.emplace_back(ReturnTy);
 
   // Add DXIL Opcode value type viz., Int32 as first argument
   ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
@@ -249,34 +249,39 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
 }
 
-static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
-                                                Type *OverloadTy, Module &M) {
-  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+namespace llvm {
+namespace dxil {
+
+// Create an instruction that calls DXIL Op with return type, specified opcode,
+// and call arguments.
+// \param OpCode Opcode of the DXIL Op call constructed
+// \param ReturnTy Return type of the DXIL Op call constructed
+// \param OverloadTy Overload type of the DXIL Op call constructed
+// \ret DXIL Op call constructed
+CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
+                                          Type *OverloadTy,
+                                          llvm::iterator_range<Use *> Args) {
+  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
 
   OverloadKind Kind = getOverloadKind(OverloadTy);
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }
 
-  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
-  // Dependent on name to dedup.
-  if (auto *Fn = M.getFunction(FnName))
-    return FunctionCallee(Fn);
-
-  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
-  return M.getOrInsertFunction(FnName, DXILOpFT);
-}
-
-namespace llvm {
-namespace dxil {
-
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
-                                          llvm::iterator_range<Use *> Args) {
-  auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
+  std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
+  FunctionCallee DXILFn;
+  // Get the function with name DXILFnName, if one exists
+  if (auto *Func = M.getFunction(DXILFnName)) {
+    DXILFn = FunctionCallee(Func);
+  } else {
+    // Construct and add a function with name DXILFnName
+    FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
+    DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
+  }
   SmallVector<Value *> FullArgs;
   FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
   FullArgs.append(Args.begin(), Args.end());
-  return B.CreateCall(Fn, FullArgs);
+  return B.CreateCall(DXILFn, FullArgs);
 }
 
 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 1c15f109184adf..89e91e2a0784db 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -29,7 +29,8 @@ namespace dxil {
 class DXILOpBuilder {
 public:
   DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
-  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
+  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
+                             Type *OverloadTy,
                              llvm::iterator_range<Use *> Args);
   Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index e5c2042e7d16ae..9af94834565457 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -44,7 +44,12 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     Args.emplace_back(DXILOpArg);
     Args.append(CI->arg_begin(), CI->arg_end());
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+    // Return type of F may be different from that of its arguments.
+    // It may not correspond to any overload type (if one exists) of the
+    // parameters. Pass the return type and overload type separately to handle
+    // such situations.
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
+                                              OverloadTy, CI->args());
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/isinf.ll b/llvm/test/CodeGen/DirectX/isinf.ll
new file mode 100644
index 00000000000000..e2975da90bfc1b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/isinf.ll
@@ -0,0 +1,25 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for isinf are generated for float and half.
+; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
+; CHECK: call i1 @dx.op.isSpecialFloat.f16(i32 9, half %{{.*}})
+
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @isinf_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %dx.isinf = call i1 @llvm.dx.isinf.f32(float %0)
+  ret i1 %dx.isinf
+}
+
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @isinf_half(half noundef %p0) #0 {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.isinf = call i1 @llvm.dx.isinf.f16(half %0)
+  ret i1 %dx.isinf
+}
diff --git a/llvm/test/CodeGen/DirectX/isinf_error.ll b/llvm/test/CodeGen/DirectX/isinf_error.ll
new file mode 100644
index 00000000000000..95b2d0cabcc43b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/isinf_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation isinf does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef i1 @isinf_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %dx.isinf = call i1 @llvm.dx.isinf.f64(double %0)
+  ret i1 %dx.isinf
+}
diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 59089929837ebb..cdfc8d3d174487 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -119,7 +119,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
-  // This vector contains overload parameters in the order order used to
+  // This vector contains overload parameters in the order used to
   // resolve an LLVMMatchType in accordance with  convention outlined in
   // the comment before the definition of class LLVMMatchType in
   // llvm/IR/Intrinsics.td
@@ -398,10 +398,13 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
   for (auto &Op : Ops) {
+    // If no overload parameter exists, treat the return type as overload
+    // parameter to emit the appropriate overload kind enum.
+    auto OLParamIdx = (Op.OverloadParamIndex < 0) ? 0 : Op.OverloadParamIndex;
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
-       << getOverloadKindStr(Op.OpTypes[0]) << ", "
+       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";

``````````

</details>


https://github.com/llvm/llvm-project/pull/85646


More information about the llvm-commits mailing list