[llvm] 3f39571 - [DirectX][DXIL] Distinguish return type for overload type resolution. (#85646)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 11:48:19 PDT 2024


Author: S. Bharadwaj Yadavalli
Date: 2024-03-20T14:48:16-04:00
New Revision: 3f39571228fe2cf402e6ea5727cd5b32f9299356

URL: https://github.com/llvm/llvm-project/commit/3f39571228fe2cf402e6ea5727cd5b32f9299356
DIFF: https://github.com/llvm/llvm-project/commit/3f39571228fe2cf402e6ea5727cd5b32f9299356.diff

LOG: [DirectX][DXIL] Distinguish return type for overload type resolution. (#85646)

The return type of a DXIL Op may differ from the valid overload type of
its parameters, if any. Such DXIL Ops are correctly represented in DXIL.td.
However, DXILEmitter assumes that the return type is the same as the
parameter overload type, if one exists. This results in the generation of
an incorrect overload index value in DXILOperation.inc for the DXIL Op and
an incorrect DXIL operation function call in the DXILOpLowering pass.

This change correctly distinguishes return types from parameter overload
types in the DXILEmitter backend to handle such DXIL Ops.

Add the specification for DXIL Op `isinf` and corresponding tests to verify
the above change.

Fixes issue #85125

Added: 
    llvm/test/CodeGen/DirectX/isinf.ll
    llvm/test/CodeGen/DirectX/isinf_error.ll

Modified: 
    llvm/lib/Target/DirectX/DXIL.td
    llvm/lib/Target/DirectX/DXILOpBuilder.cpp
    llvm/lib/Target/DirectX/DXILOpBuilder.h
    llvm/lib/Target/DirectX/DXILOpLowering.cpp
    llvm/utils/TableGen/DXILEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 216fa5b10c8f4d..36eb29d53766f0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -255,6 +255,9 @@ class DXILOpMapping<int opCode, DXILOpClass opClass,
 }
 
 // Concrete definition of DXIL Operation mapping to corresponding LLVM intrinsic
+def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf,
+                         "Determines if the specified value is infinite.",
+                         [llvm_i1_ty, llvm_halforfloat_ty]>;
 def Sin  : DXILOpMapping<13, unary, int_sin,
                          "Returns sine(theta) for theta in radians.",
                          [llvm_halforfloat_ty, LLVMMatchType<0>]>;

diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 11b24d04492368..a1eacc2d48009c 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,13 +229,13 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
 ///               its specification in DXIL.td.
 /// \param OverloadTy Return type to be used to construct DXIL function type.
 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
-                                           Type *OverloadTy) {
+                                           Type *ReturnTy, Type *OverloadTy) {
   SmallVector<Type *> ArgTys;
 
   auto ParamKinds = getOpCodeParameterKind(*Prop);
 
-  // Add OverloadTy as return type of the function
-  ArgTys.emplace_back(OverloadTy);
+  // Add ReturnTy as return type of the function
+  ArgTys.emplace_back(ReturnTy);
 
   // Add DXIL Opcode value type viz., Int32 as first argument
   ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
@@ -249,34 +249,33 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
 }
 
-static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp,
-                                                Type *OverloadTy, Module &M) {
-  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+namespace llvm {
+namespace dxil {
+
+CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
+                                          Type *OverloadTy,
+                                          llvm::iterator_range<Use *> Args) {
+  const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
 
   OverloadKind Kind = getOverloadKind(OverloadTy);
   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
   }
 
-  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
-  // Dependent on name to dedup.
-  if (auto *Fn = M.getFunction(FnName))
-    return FunctionCallee(Fn);
-
-  FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy);
-  return M.getOrInsertFunction(FnName, DXILOpFT);
-}
-
-namespace llvm {
-namespace dxil {
-
-CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
-                                          llvm::iterator_range<Use *> Args) {
-  auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M);
+  std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
+  FunctionCallee DXILFn;
+  // Get the function with name DXILFnName, if one exists
+  if (auto *Func = M.getFunction(DXILFnName)) {
+    DXILFn = FunctionCallee(Func);
+  } else {
+    // Construct and add a function with name DXILFnName
+    FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
+    DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
+  }
   SmallVector<Value *> FullArgs;
   FullArgs.emplace_back(B.getInt32((int32_t)OpCode));
   FullArgs.append(Args.begin(), Args.end());
-  return B.CreateCall(Fn, FullArgs);
+  return B.CreateCall(DXILFn, FullArgs);
 }
 
 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {

diff  --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 1c15f109184adf..f3abcc6e02a4e3 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -29,7 +29,13 @@ namespace dxil {
 class DXILOpBuilder {
 public:
   DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) {}
-  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy,
+  /// Create an instruction that calls DXIL Op with return type, specified
+  /// opcode, and call arguments. \param OpCode Opcode of the DXIL Op call
+  /// constructed \param ReturnTy Return type of the DXIL Op call constructed
+  /// \param OverloadTy Overload type of the DXIL Op call constructed
+  /// \return DXIL Op call constructed
+  CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
+                             Type *OverloadTy,
                              llvm::iterator_range<Use *> Args);
   Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT);
   static const char *getOpCodeName(dxil::OpCode DXILOp);

diff  --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index e5c2042e7d16ae..3e334b0ec298d3 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -32,7 +32,6 @@ using namespace llvm::dxil;
 
 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
   IRBuilder<> B(M.getContext());
-  Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
   DXILOpBuilder DXILB(M, B);
   Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
   for (User *U : make_early_inc_range(F.users())) {
@@ -40,11 +39,9 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
     if (!CI)
       continue;
 
-    SmallVector<Value *> Args;
-    Args.emplace_back(DXILOpArg);
-    Args.append(CI->arg_begin(), CI->arg_end());
     B.SetInsertPoint(CI);
-    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, OverloadTy, CI->args());
+    CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(),
+                                              OverloadTy, CI->args());
 
     CI->replaceAllUsesWith(DXILCI);
     CI->eraseFromParent();

diff  --git a/llvm/test/CodeGen/DirectX/isinf.ll b/llvm/test/CodeGen/DirectX/isinf.ll
new file mode 100644
index 00000000000000..e2975da90bfc1b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/isinf.ll
@@ -0,0 +1,25 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for isinf are generated for float and half.
+; CHECK: call i1 @dx.op.isSpecialFloat.f32(i32 9, float %{{.*}})
+; CHECK: call i1 @dx.op.isSpecialFloat.f16(i32 9, half %{{.*}})
+
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @isinf_float(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %dx.isinf = call i1 @llvm.dx.isinf.f32(float %0)
+  ret i1 %dx.isinf
+}
+
+; Function Attrs: noinline nounwind optnone
+define noundef i1 @isinf_half(half noundef %p0) #0 {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.isinf = call i1 @llvm.dx.isinf.f16(half %0)
+  ret i1 %dx.isinf
+}

diff  --git a/llvm/test/CodeGen/DirectX/isinf_error.ll b/llvm/test/CodeGen/DirectX/isinf_error.ll
new file mode 100644
index 00000000000000..95b2d0cabcc43b
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/isinf_error.ll
@@ -0,0 +1,13 @@
+; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s
+
+; DXIL operation isinf does not support double overload type
+; CHECK: LLVM ERROR: Invalid Overload Type
+
+define noundef i1 @isinf_double(double noundef %a) #0 {
+entry:
+  %a.addr = alloca double, align 8
+  store double %a, ptr %a.addr, align 8
+  %0 = load double, ptr %a.addr, align 8
+  %dx.isinf = call i1 @llvm.dx.isinf.f64(double %0)
+  ret i1 %dx.isinf
+}

diff  --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp
index 59089929837ebb..af1efb8aa99f73 100644
--- a/llvm/utils/TableGen/DXILEmitter.cpp
+++ b/llvm/utils/TableGen/DXILEmitter.cpp
@@ -119,7 +119,7 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) {
   // Populate OpTypes with return type and parameter types
 
   // Parameter indices of overloaded parameters.
-  // This vector contains overload parameters in the order order used to
+  // This vector contains overload parameters in the order used to
   // resolve an LLVMMatchType in accordance with  convention outlined in
   // the comment before the definition of class LLVMMatchType in
   // llvm/IR/Intrinsics.td
@@ -398,10 +398,20 @@ static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
 
   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
   for (auto &Op : Ops) {
+    // Consider Op.OverloadParamIndex as the overload parameter index, by
+    // default
+    auto OLParamIdx = Op.OverloadParamIndex;
+    // If no overload parameter index is set, treat first parameter type as
+    // overload type - unless the Op has no parameters, in which case treat the
+    // return type - as overload parameter to emit the appropriate overload kind
+    // enum.
+    if (OLParamIdx < 0) {
+      OLParamIdx = (Op.OpTypes.size() > 1) ? 1 : 0;
+    }
     OS << "  { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
        << ", OpCodeClass::" << Op.OpClass << ", "
        << OpClassStrings.get(Op.OpClass.data()) << ", "
-       << getOverloadKindStr(Op.OpTypes[0]) << ", "
+       << getOverloadKindStr(Op.OpTypes[OLParamIdx]) << ", "
        << emitDXILOperationAttr(Op.OpAttributes) << ", "
        << Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
        << Parameters.get(ParameterMap[Op.OpClass]) << " },\n";


        


More information about the llvm-commits mailing list