[clang] [llvm] Inbelic/as double (PR #114847)

Finn Plummer via cfe-commits cfe-commits at lists.llvm.org
Mon Nov 4 10:29:55 PST 2024


https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/114847

None

>From f340a6f0421693bd3489adc1c68983dfae9646dd Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 4 Nov 2024 17:38:36 +0000
Subject: [PATCH 1/4] [NFC][Scalarizer][TargetTransformInfo] Add
 `isVectorIntrinsicWithOverloadTypeAtArg`

This change allows target intrinsics to specify overloaded types.

This change will let us add scalarization for `asdouble`.
---
 llvm/include/llvm/Analysis/TargetTransformInfo.h   | 14 ++++++++++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h        |  6 ++++++
 llvm/include/llvm/CodeGen/BasicTTIImpl.h           |  6 ++++++
 llvm/lib/Analysis/TargetTransformInfo.cpp          |  6 ++++++
 .../Target/DirectX/DirectXTargetTransformInfo.cpp  |  8 ++++++++
 .../Target/DirectX/DirectXTargetTransformInfo.h    |  3 +++
 llvm/lib/Transforms/Scalar/Scalarizer.cpp          |  9 ++++++---
 7 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 0459941fe05cdc..796b4011d71c0c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -896,6 +896,14 @@ class TargetTransformInfo {
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx) const;
 
+  /// Identifies if the vector form of intrinsic \p ID is overloaded on the
+  /// type of the operand at index \p ScalarOpdIdx, or on the return type when
+  /// \p ScalarOpdIdx is -1. Targets may override the answer for their own
+  /// intrinsics; otherwise \p Default is returned.
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const;
+
   /// Estimate the overhead of scalarizing an instruction. Insert and Extract
   /// are set if the demanded result elements need to be inserted and/or
   /// extracted from vectors.
@@ -1969,6 +1977,9 @@ class TargetTransformInfo::Concept {
   virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
   virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                                   unsigned ScalarOpdIdx) = 0;
+  virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                                      unsigned ScalarOpdIdx,
+                                                      bool Default) = 0;
   virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
                                                    const APInt &DemandedElts,
                                                    bool Insert, bool Extract,
@@ -2530,6 +2541,13 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) override {
+    return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                       Default);
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index dbdfb4d8cdfa32..42d1082cf4d9eb 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -392,6 +392,12 @@ class TargetTransformInfoImplBase {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   InstructionCost getScalarizationOverhead(VectorType *Ty,
                                            const APInt &DemandedElts,
                                            bool Insert, bool Extract,
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index db3b5cddd7c1c3..b2841e778947dd 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -798,6 +798,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return false;
   }
 
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default) const {
+    return Default;
+  }
+
   /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
   InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
                                            bool Extract,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index a47462b61e03b2..bf9733f971fdac 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -612,6 +612,12 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
   return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
 }
 
+bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) const {
+  return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx,
+                                                         Default);
+}
+
 InstructionCost TargetTransformInfo::getScalarizationOverhead(
     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
     TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 231afd8ae3eeaf..2be5fc4de2409c 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
   }
 }
 
+bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
+    Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
+  switch (ID) {
+  default:
+    return Default;
+  }
+}
+
 bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
     Intrinsic::ID ID) const {
   switch (ID) {
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
index 30b57ed97d6370..ff82b7404ca58a 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
@@ -37,6 +37,9 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
   bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
   bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                           unsigned ScalarOpdIdx);
+  bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
+                                              unsigned ScalarOpdIdx,
+                                              bool Default);
 };
 } // namespace llvm
 
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 772f4c6c35ddec..719dce704872ae 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -727,7 +727,15 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
 
+  // Whether the intrinsic is overloaded on the type of operand OpdIdx (or on
+  // the return type when OpdIdx is -1). The target-independent answer is
+  // passed as the default, so targets only override what differs.
+  auto IsOverloadedAt = [&](int OpdIdx) {
+    return TTI->isVectorIntrinsicWithOverloadTypeAtArg(
+        ID, OpdIdx, isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx));
+  };
+
   SmallVector<llvm::Type *, 3> Tys;
   // Add return type if intrinsic is overloaded on it.
-  if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
+  if (IsOverloadedAt(-1))
     Tys.push_back(VS->SplitTy);
 
   if (AreAllVectorsOfMatchingSize) {
@@ -767,13 +775,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
       }
 
       Scattered[I] = scatter(&CI, OpI, *OpVS);
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
+      if (IsOverloadedAt(I)) {
         OverloadIdx[I] = Tys.size();
         Tys.push_back(OpVS->SplitTy);
       }
     } else {
       ScalarOperands[I] = OpI;
-      if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
+      if (IsOverloadedAt(I))
         Tys.push_back(OpI->getType());
     }
   }

>From 332f80b7ab8b5495e3491298e7aee2e5b8ede2b4 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Mon, 28 Oct 2024 23:09:06 +0000
Subject: [PATCH 2/4] [HLSL] Add `asdouble` builtin

- define intrinsic as builtin in Builtins.td
- link intrinsic in hlsl_intrinsics.h
- add semantic analysis to SemaHLSL.cpp
- add basic sema checking to asdouble-errors.hlsl
---
 clang/include/clang/Basic/Builtins.td          |  6 ++++++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h       | 18 ++++++++++++++++++
 clang/lib/Sema/SemaHLSL.cpp                    | 17 +++++++++++++++++
 .../SemaHLSL/BuiltIns/asdouble-errors.hlsl     | 16 ++++++++++++++++
 4 files changed, 57 insertions(+)
 create mode 100644 clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..37ef0bf7324ffb 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(...)";
 }
 
+def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..7dd9c136d1d3f4 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -361,6 +361,24 @@ bool any(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_any)
 bool any(double4);
 
+//===----------------------------------------------------------------------===//
+// asdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn double asdouble(uint LowBits, uint HighBits)
+/// \brief Reinterprets two 32-bit values as the bit pattern of a double.
+/// \param LowBits The low 32 bits of the resulting double.
+/// \param HighBits The high 32 bits of the resulting double.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double asdouble(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double2 asdouble(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double3 asdouble(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asdouble)
+double4 asdouble(uint4, uint4);
+
 //===----------------------------------------------------------------------===//
 // asfloat builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..69de0294cb7c7c 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1870,6 +1870,28 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asdouble: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+    if (CheckUnsignedIntRepresentation(&SemaRef, TheCall))
+      return true;
+    // Both operands must be scalars, or vectors with the same element type and
+    // length, since they are combined element-wise.
+    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+      return true;
+
+    // The result is a double, or a vector of doubles with the same number of
+    // elements as the arguments.
+    ASTContext &Ctx = SemaRef.getASTContext();
+    auto *VTy = TheCall->getArg(0)->getType()->getAs<VectorType>();
+
+    QualType ResultType =
+        VTy ? Ctx.getVectorType(Ctx.DoubleTy, VTy->getNumElements(),
+                                VTy->getVectorKind())
+            : Ctx.DoubleTy;
+    TheCall->setType(ResultType);
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
new file mode 100644
index 00000000000000..c6b57c76a1e2b3
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/asdouble-errors.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+double test_too_few_arg() {
+  return __builtin_hlsl_asdouble();
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+double test_too_few_arg_1(uint p0) {
+  return __builtin_hlsl_asdouble(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
+}
+
+double test_too_many_arg(uint p0) {
+  return __builtin_hlsl_asdouble(p0, p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
+}

>From f765c06a1424bf5b97ba60b6dc03b5dcde992154 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Sat, 2 Nov 2024 01:50:08 +0000
Subject: [PATCH 3/4] [HLSL] Add codegen for the asdouble builtin

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 32 +++++++++++++++++++
 clang/test/CodeGenHLSL/builtins/asdouble.hlsl | 29 +++++++++++++++++
 2 files changed, 61 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/asdouble.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..fee0d258366b5c 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18634,6 +18634,40 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
   return RT.getUDotIntrinsic();
 }
 
+// Emits asdouble(LowBits, HighBits): the unsigned operands (scalars are first
+// splatted to 1-element vectors) are interleaved element-wise, low bits first,
+// and the interleaved vector is bitcast to double (or <N x double>).
+static Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
+  assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
+         E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+         "asdouble operands must have an unsigned integer representation");
+  Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
+  Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
+
+  llvm::Type *ResultType = CGF.DoubleTy;
+  int N = 1;
+  if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
+    N = VTy->getNumElements();
+    ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
+  }
+
+  if (!E->getArg(0)->getType()->isVectorType()) {
+    OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
+    OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
+  }
+
+  // Mask <0, N, 1, N + 1, ...> pairs each low element with its high element.
+  llvm::SmallVector<int> Mask;
+  for (int I = 0; I < N; ++I) {
+    Mask.push_back(I);
+    Mask.push_back(I + N);
+  }
+
+  Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
+
+  return CGF.Builder.CreateBitCast(BitVec, ResultType);
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -18655,6 +18689,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
         "hlsl.any");
   }
+  case Builtin::BI__builtin_hlsl_asdouble:
+    return handleAsDoubleBuiltin(*this, E);
   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
     Value *OpX = EmitScalarExpr(E->getArg(0));
     Value *OpMin = EmitScalarExpr(E->getArg(1));
diff --git a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
new file mode 100644
index 00000000000000..9d9ef048ee35fe
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s
+
+// Test lowering of asdouble expansion to shuffle/bitcast and splat when required
+
+// CHECK-LABEL: test_uint
+double test_uint(uint low, uint high) {
+  // CHECK: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
+  // CHECK: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
+
+  // CHECK:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
+  // CHECK-SAME: {{.*}} <i32 0, i32 1>
+  // CHECK:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+  return asdouble(low, high);
+}
+
+// CHECK-LABEL: test_vuint
+double3 test_vuint(uint3 low, uint3 high) {
+  // CHECK:      %[[SHUFFLE1:.*]] = shufflevector
+  // CHECK-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+  // CHECK:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+  return asdouble(low, high);
+}

>From 3030765d8362783687edc968aeeb9a5ea477b0b5 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Sat, 2 Nov 2024 02:26:04 +0000
Subject: [PATCH 4/4] [DXIL] Use a dx intrinsic for the DirectX backend

---
 clang/lib/CodeGen/CGBuiltin.cpp               |  5 +++
 clang/test/CodeGenHLSL/builtins/asdouble.hlsl | 34 ++++++++++++-------
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  1 +
 llvm/lib/Target/DirectX/DXIL.td               | 10 ++++++
 .../DirectX/DirectXTargetTransformInfo.cpp    |  3 ++
 llvm/test/CodeGen/DirectX/asdouble.ll         | 22 ++++++++++++
 6 files changed, 62 insertions(+), 13 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/asdouble.ll

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index fee0d258366b5c..85ad203b50c7ab 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18648,6 +18648,11 @@ Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
     ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
   }
 
+  if (CGF.CGM.getTarget().getTriple().isDXIL())
+    return CGF.Builder.CreateIntrinsic(
+        /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
+        ArrayRef<Value *>{OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
+
   if (!E->getArg(0)->getType()->isVectorType()) {
     OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
     OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
diff --git a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
index 9d9ef048ee35fe..f1c31107cdcad6 100644
--- a/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asdouble.hlsl
@@ -1,29 +1,37 @@
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
-// RUN:   FileCheck %s
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
-// RUN:   FileCheck %s
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPV
 
 // Test lowering of asdouble expansion to shuffle/bitcast and splat when required
 
 // CHECK-LABEL: test_uint
 double test_uint(uint low, uint high) {
-  // CHECK: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
-  // CHECK: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
-  // CHECK: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
-  // CHECK: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
-
-  // CHECK:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
-  // CHECK-SAME: {{.*}} <i32 0, i32 1>
-  // CHECK:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+  // CHECK-SPV: %[[LOW_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[LOW_SHUFFLE:.*]] = shufflevector <1 x i32> %[[LOW_INSERT]], {{.*}} zeroinitializer
+  // CHECK-SPV: %[[HIGH_INSERT:.*]] = insertelement <1 x i32>
+  // CHECK-SPV: %[[HIGH_SHUFFLE:.*]] = shufflevector <1 x i32> %[[HIGH_INSERT]], {{.*}} zeroinitializer
+
+  // CHECK-SPV:      %[[SHUFFLE0:.*]] = shufflevector <1 x i32> %[[LOW_SHUFFLE]], <1 x i32> %[[HIGH_SHUFFLE]],
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 1>
+  // CHECK-SPV:      bitcast <2 x i32> %[[SHUFFLE0]] to double
+
+  // CHECK-DXIL: call double @llvm.dx.asdouble.i32
   return asdouble(low, high);
 }
 
+// CHECK-DXIL: declare double @llvm.dx.asdouble.i32
+
 // CHECK-LABEL: test_vuint
 double3 test_vuint(uint3 low, uint3 high) {
-  // CHECK:      %[[SHUFFLE1:.*]] = shufflevector
-  // CHECK-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
-  // CHECK:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+  // CHECK-SPV:      %[[SHUFFLE1:.*]] = shufflevector
+  // CHECK-SPV-SAME: {{.*}} <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
+  // CHECK-SPV:      bitcast <6 x i32> %[[SHUFFLE1]] to <3 x double>
+
+  // CHECK-DXIL: call <3 x double> @llvm.dx.asdouble.v3i32
   return asdouble(low, high);
 }
+
+// CHECK-DXIL: declare <3 x double> @llvm.dx.asdouble.v3i32
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..904607a98aa86e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -40,6 +40,7 @@ def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
 
 def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>;
+def int_dx_asdouble : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [llvm_anyint_ty, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..6a1edb9f6debe5 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -779,6 +779,15 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
+def MakeDouble : DXILOp<101, makeDouble> {
+  let Doc = "creates a double value";
+  let LLVMIntrinsic = int_dx_asdouble;
+  let arguments = [Int32Ty, Int32Ty];
+  let result = DoubleTy;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
   let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 2be5fc4de2409c..a115a664209445 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -28,6 +28,10 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
 bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(
     Intrinsic::ID ID, unsigned ScalarOpdIdx, bool Default) {
   switch (ID) {
+  // dx.asdouble is overloaded only on its first operand: the second operand
+  // matches it and the result type is derived from it, not overloaded.
+  case Intrinsic::dx_asdouble:
+    return ScalarOpdIdx == 0;
   default:
     return Default;
   }
@@ -39,6 +43,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   case Intrinsic::dx_frac:
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_wave_readlane:
+  case Intrinsic::dx_asdouble:
   case Intrinsic::dx_splitdouble:
     return true;
   default:
diff --git a/llvm/test/CodeGen/DirectX/asdouble.ll b/llvm/test/CodeGen/DirectX/asdouble.ll
new file mode 100644
index 00000000000000..6a581d69eb7e9d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/asdouble.ll
@@ -0,0 +1,22 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; Test that dx.asdouble lowers to the makeDouble DXIL op (opcode 101); vector
+; operands are first split by the scalarizer into one call per element.
+
+define noundef double @asdouble_scalar(i32 noundef %low, i32 noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low, i32 %high)
+  %ret = call double @llvm.dx.asdouble.i32(i32 %low, i32 %high)
+  ret double %ret
+}
+
+declare double @llvm.dx.asdouble.i32(i32, i32)
+
+define noundef <3 x double> @asdouble_vec(<3 x i32> noundef %low, <3 x i32> noundef %high) {
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i0, i32 %high.i0)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i1, i32 %high.i1)
+; CHECK: call double @dx.op.makeDouble(i32 101, i32 %low.i2, i32 %high.i2)
+  %ret = call <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32> %low, <3 x i32> %high)
+  ret <3 x double> %ret
+}
+
+declare <3 x double> @llvm.dx.asdouble.v3i32(<3 x i32>, <3 x i32>)



More information about the cfe-commits mailing list