[clang] [llvm] Adding splitdouble HLSL function (PR #109331)

via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 28 09:32:45 PDT 2024


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/109331

>From 7eacb87b6903996aaa044854a83f404c68907e06 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 19 Sep 2024 00:13:51 +0000
Subject: [PATCH 01/18] Add codegen for __builtin_hlsl_asuint_splitdouble

---
 clang/include/clang/Basic/Builtins.td         |  6 ++
 clang/lib/CodeGen/CGBuiltin.cpp               | 38 ++++++++++++
 clang/lib/CodeGen/CGCall.cpp                  |  5 ++
 clang/lib/CodeGen/CGExpr.cpp                  | 15 ++++-
 clang/lib/CodeGen/CodeGenFunction.h           | 10 +++-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 20 +++++++
 clang/lib/Sema/SemaHLSL.cpp                   | 58 ++++++++++++++++---
 .../builtins/asuint-splitdouble.hlsl          | 10 ++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  5 ++
 llvm/lib/Target/DirectX/DXIL.td               |  1 +
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  1 +
 11 files changed, 155 insertions(+), 14 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..1bff7e6838836c 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,6 +4871,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..70fd0261bfe150 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,44 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  // This should only be called when targeting DXIL
+  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+
+    assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+            E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+            E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+           "asuint operands types mismatch");
+
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+    const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+    CallArgList Args;
+    LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+    LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+    llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+    if (Op0->getType()->isVectorTy()) {
+      auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+      llvm::VectorType *i32VecTy = llvm::VectorType::get(
+          Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+
+      retType = llvm::StructType::get(i32VecTy, i32VecTy);
+    }
+
+    CallInst *CI =
+        Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
+                                {Op0}, nullptr, "hlsl.asuint");
+
+    Value *arg0 = Builder.CreateExtractValue(CI, 0);
+    Value *arg1 = Builder.CreateExtractValue(CI, 1);
+
+    Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+    auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+    EmitWritebacks(*this, Args);
+    return s;
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 1949b4ceb7f204..ebdb04331ea6fb 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4681,6 +4681,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
   IsUsed = true;
 }
 
+void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF,
+                                     const CallArgList &args) {
+  emitWritebacks(CGF, args);
+}
+
 void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
                                   QualType type) {
   DisableDebugLocationUpdates Dis(*this, E);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e0ea65bcaf3637..9f6743844dfba8 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,6 +19,7 @@
 #include "CGObjCRuntime.h"
 #include "CGOpenMPRuntime.h"
 #include "CGRecordLayout.h"
+#include "CGValue.h"
 #include "CodeGenFunction.h"
 #include "CodeGenModule.h"
 #include "ConstantEmitter.h"
@@ -28,6 +29,7 @@
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/StmtVisitor.h"
+#include "clang/AST/Type.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/CodeGenOptions.h"
 #include "clang/Basic/SourceManager.h"
@@ -5460,9 +5462,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
   return getOrCreateOpaqueLValueMapping(e);
 }
 
-void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
-                                         CallArgList &Args, QualType Ty) {
-
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
   // Emitting the casted temporary through an opaque value.
   LValue BaseLV = EmitLValue(E->getArgLValue());
   OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
@@ -5476,6 +5477,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
                                TempLV);
 
   OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
+  return std::make_pair(BaseLV, TempLV);
+}
+
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+                                           CallArgList &Args, QualType Ty) {
+
+  auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
 
   llvm::Value *Addr = TempLV.getAddress().getBasePointer();
   llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
@@ -5488,6 +5496,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
   Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
                     LifetimeSize);
   Args.add(RValue::get(TmpAddr, *this), Ty);
+  return TempLV;
 }
 
 LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 750a6cc24badca..22a6810ff5bc7a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
   LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
   LValue EmitHLSLArrayAssignLValue(const BinaryOperator *E);
-  void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
-                          QualType Ty);
+
+  std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
+                                                  QualType Ty);
+  LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+                            QualType Ty);
 
   Address EmitExtVectorElementLValue(LValue V);
 
@@ -5147,6 +5150,9 @@ class CodeGenFunction : public CodeGenTypeCache {
                            SourceLocation ArgLoc, AbstractCallee AC,
                            unsigned ParmNum);
 
+  /// EmitWriteback - Emit callbacks for function.
+  void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args);
+
   /// EmitCallArg - Emit a single call argument.
   void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..406d5e888fdd64 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -438,6 +438,26 @@ template <typename T> constexpr uint asuint(T F) {
   return __detail::bit_cast<uint, T>(F);
 }
 
+//===----------------------------------------------------------------------===//
+// asuint splitdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn void asuint(double D, out uint lowbits, out int highbits)
+/// \brief Split and interprets the lowbits and highbits of double D into uints.
+/// \param D The input double.
+/// \param lowbits The output lowbits of D.
+/// \param highbits The highbits lowbits D.
+#if __is_target_arch(dxil)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double, out uint, out uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double2, out uint2, out uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double3, out uint3, out uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double4, out uint4, out uint4);
+#endif
+
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..ac7b9bc9862abb 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,18 +1698,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
-static bool CheckArgsTypesAreCorrect(
+bool CheckArgTypeIsCorrect(
+    Sema *S, Expr *Arg, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+  QualType PassedType = Arg->getType();
+  if (Check(PassedType)) {
+    if (auto *VecTyA = PassedType->getAs<VectorType>())
+      ExpectedType = S->Context.getVectorType(
+          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
+    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+        << PassedType << ExpectedType << 1 << 0 << 0;
+    return true;
+  }
+  return false;
+}
+
+bool CheckArgsTypesAreCorrect(
     Sema *S, CallExpr *TheCall, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    QualType PassedType = TheCall->getArg(i)->getType();
-    if (Check(PassedType)) {
-      if (auto *VecTyA = PassedType->getAs<VectorType>())
-        ExpectedType = S->Context.getVectorType(
-            ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-      S->Diag(TheCall->getArg(0)->getBeginLoc(),
-              diag::err_typecheck_convert_incompatible)
-          << PassedType << ExpectedType << 1 << 0 << 0;
+    Expr *Arg = TheCall->getArg(i);
+    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
       return true;
     }
   }
@@ -2074,6 +2083,37 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    // Expr *Op0 = TheCall->getArg(0);
+
+    // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+    //   return !PassedType->isDoubleType();
+    // };
+
+    // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+    //                           CheckIsNotDouble)) {
+    //   return true;
+    // }
+
+    // Expr *Op1 = TheCall->getArg(1);
+    // Expr *Op2 = TheCall->getArg(2);
+
+    // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+    //   return !PassedType->isUnsignedIntegerType();
+    // };
+
+    // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+    //                           CheckIsNotUint) ||
+    //     CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+    //                           CheckIsNotUint)) {
+    //   return true;
+    // }
+
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
new file mode 100644
index 00000000000000..e359354dc3a6df
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+
+// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} 
+// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
+// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
+float fn(double D) {
+  uint A, B;
+  asuint(D, A, B);
+  return A + B;
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..197a50ac585798 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,4 +92,9 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
 def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+
+def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+    [llvm_anyint_ty, LLVMMatchType<0>], 
+    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
+    [IntrNoMem, IntrWillReturn]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..019f0b36cf4ff0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -778,6 +778,7 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
+//
 
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index fb5383b3514a5a..fc2755dfb24252 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,6 +12,7 @@
 
 #include "DXILIntrinsicExpansion.h"
 #include "DirectX.h"
+#include "llvm-c/Core.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/DXILResource.h"

>From 0057d1306d0259ddd29f3bb01b592d9bbd7e0aa1 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 23 Sep 2024 21:19:12 +0000
Subject: [PATCH 02/18] adding vector case for splitdouble

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 62 ++++++++++++++-----
 clang/lib/CodeGen/CGExpr.cpp                  |  8 ++-
 clang/lib/CodeGen/CodeGenFunction.h           |  4 +-
 .../builtins/asuint-splitdouble.hlsl          |  4 +-
 4 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 70fd0261bfe150..7c3fcb5a7b1ba2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,12 +34,14 @@
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -67,6 +69,7 @@
 #include "llvm/TargetParser/X86TargetParser.h"
 #include <optional>
 #include <sstream>
+#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
@@ -18971,29 +18974,60 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
     const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
 
+    auto emitSplitDouble =
+        [](CGBuilderTy *Builder, llvm::Value *arg,
+           llvm::Type *retType) -> std::pair<Value *, Value *> {
+      CallInst *CI = Builder->CreateIntrinsic(
+          retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
+          "hlsl.asuint");
+
+      Value *arg0 = Builder->CreateExtractValue(CI, 0);
+      Value *arg1 = Builder->CreateExtractValue(CI, 1);
+
+      return std::make_pair(arg0, arg1);
+    };
+
     CallArgList Args;
-    LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
-    LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+    auto [Op1BaseLValue, Op1TmpLValue] =
+        EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+    auto [Op2BaseLValue, Op2TmpLValue] =
+        EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
     llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
-    if (Op0->getType()->isVectorTy()) {
-      auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
 
-      llvm::VectorType *i32VecTy = llvm::VectorType::get(
-          Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+    if (!Op0->getType()->isVectorTy()) {
+      auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
+
+      Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+      auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
 
-      retType = llvm::StructType::get(i32VecTy, i32VecTy);
+      EmitWritebacks(*this, Args);
+      return s;
     }
 
-    CallInst *CI =
-        Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
-                                {Op0}, nullptr, "hlsl.asuint");
+    auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+    llvm::VectorType *i32VecTy = llvm::VectorType::get(
+        Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
 
-    Value *arg0 = Builder.CreateExtractValue(CI, 0);
-    Value *arg1 = Builder.CreateExtractValue(CI, 1);
+    std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
+
+    for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
+      Value *op = Builder.CreateExtractElement(Op0, idx);
+
+      auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType);
+
+      if (idx == 0) {
+        inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
+        inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
+      } else {
+        inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
+        inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx);
+      }
+    }
 
-    Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
-    auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+    Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
+    auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
     EmitWritebacks(*this, Args);
     return s;
   }
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 9f6743844dfba8..83c7f6537d22fb 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -54,6 +54,7 @@
 
 #include <optional>
 #include <string>
+#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
@@ -5480,8 +5481,9 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
   return std::make_pair(BaseLV, TempLV);
 }
 
-LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
-                                           CallArgList &Args, QualType Ty) {
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+                                    QualType Ty) {
 
   auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
 
@@ -5496,7 +5498,7 @@ LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
   Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
                     LifetimeSize);
   Args.add(RValue::get(TmpAddr, *this), Ty);
-  return TempLV;
+  return std::make_pair(BaseLV, TempLV);
 }
 
 LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 22a6810ff5bc7a..c1d6b4e2e3460d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4299,8 +4299,8 @@ class CodeGenFunction : public CodeGenTypeCache {
 
   std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
                                                   QualType Ty);
-  LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
-                            QualType Ty);
+  std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+                                               CallArgList &Args, QualType Ty);
 
   Address EmitExtVectorElementLValue(LValue V);
 
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index e359354dc3a6df..4326612db96b0f 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -3,8 +3,8 @@
 // CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} 
 // CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
 // CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float fn(double D) {
-  uint A, B;
+float2 fn(double2 D) {
+  uint2 A, B;
   asuint(D, A, B);
   return A + B;
 }

>From c359f18ccd5b1b8d761d275504a3d6d3fe674cc0 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 24 Sep 2024 00:50:10 +0000
Subject: [PATCH 03/18] adding lowering to dxil

---
 clang/include/clang/Basic/Builtins.td         | 16 +++++-
 clang/lib/CodeGen/CGBuiltin.cpp               | 15 +++---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  8 +--
 clang/lib/Sema/SemaHLSL.cpp                   | 44 ++++++++--------
 .../builtins/asuint-splitdouble.hlsl          | 25 +++++++---
 .../test/SemaHLSL/BuiltIns/asuint-errors.hlsl |  4 ++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +-
 llvm/lib/Target/DirectX/DXIL.td               | 11 +++-
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  1 -
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     | 13 +++++
 llvm/lib/Target/DirectX/DXILOpBuilder.h       |  4 ++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    | 50 +++++++++++++++++++
 12 files changed, 147 insertions(+), 46 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 1bff7e6838836c..b35a24eda6da44 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,8 +4871,20 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
-  let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_radians"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
+def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_elementwise_radians"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
+def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_splitdouble"];
   let Attributes = [NoThrow, Const];
   let Prototype = "void(...)";
 }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7c3fcb5a7b1ba2..7e442c361d2c32 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,14 +34,12 @@
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InlineAsm.h"
-#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -69,7 +67,6 @@
 #include "llvm/TargetParser/X86TargetParser.h"
 #include <optional>
 #include <sstream>
-#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
@@ -18963,7 +18960,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         nullptr, "hlsl.radians");
   }
   // This should only be called when targeting DXIL
-  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+  case Builtin::BI__builtin_hlsl_splitdouble: {
 
     assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
             E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
@@ -18977,9 +18974,9 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     auto emitSplitDouble =
         [](CGBuilderTy *Builder, llvm::Value *arg,
            llvm::Type *retType) -> std::pair<Value *, Value *> {
-      CallInst *CI = Builder->CreateIntrinsic(
-          retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
-          "hlsl.asuint");
+      CallInst *CI =
+          Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
+                                   {arg}, nullptr, "hlsl.asuint");
 
       Value *arg0 = Builder->CreateExtractValue(CI, 0);
       Value *arg1 = Builder->CreateExtractValue(CI, 1);
@@ -18993,7 +18990,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     auto [Op2BaseLValue, Op2TmpLValue] =
         EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
-    llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+    llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
 
     if (!Op0->getType()->isVectorTy()) {
       auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
@@ -19022,7 +19019,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
       } else {
         inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
-        inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx);
+        inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
       }
     }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 406d5e888fdd64..ef59ac777796f5 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -448,13 +448,13 @@ template <typename T> constexpr uint asuint(T F) {
 /// \param lowbits The output lowbits of D.
 /// \param highbits The highbits lowbits D.
 #if __is_target_arch(dxil)
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double, out uint, out uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double2, out uint2, out uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double3, out uint3, out uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double4, out uint4, out uint4);
 #endif
 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index ac7b9bc9862abb..993da91817f774 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,7 +1698,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
-bool CheckArgTypeIsCorrect(
+bool CheckArgTypeIsIncorrect(
     Sema *S, Expr *Arg, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   QualType PassedType = Arg->getType();
@@ -1718,7 +1718,7 @@ bool CheckArgsTypesAreCorrect(
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
     Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+    if (CheckArgTypeIsIncorrect(S, Arg, ExpectedType, Check)) {
       return true;
     }
   }
@@ -2083,34 +2083,34 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+  case Builtin::BI__builtin_hlsl_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    // Expr *Op0 = TheCall->getArg(0);
+    Expr *Op0 = TheCall->getArg(0);
 
-    // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
-    //   return !PassedType->isDoubleType();
-    // };
+    auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+      return !PassedType->hasFloatingRepresentation();
+    };
 
-    // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
-    //                           CheckIsNotDouble)) {
-    //   return true;
-    // }
+    if (CheckArgTypeIsIncorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+                                CheckIsNotDouble)) {
+      return true;
+    }
 
-    // Expr *Op1 = TheCall->getArg(1);
-    // Expr *Op2 = TheCall->getArg(2);
+    Expr *Op1 = TheCall->getArg(1);
+    Expr *Op2 = TheCall->getArg(2);
 
-    // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
-    //   return !PassedType->isUnsignedIntegerType();
-    // };
+    auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+      return !PassedType->hasUnsignedIntegerRepresentation();
+    };
 
-    // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
-    //                           CheckIsNotUint) ||
-    //     CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
-    //                           CheckIsNotUint)) {
-    //   return true;
-    // }
+    if (CheckArgTypeIsIncorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+                                CheckIsNotUint) ||
+        CheckArgTypeIsIncorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+                                CheckIsNotUint)) {
+      return true;
+    }
 
     break;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index 4326612db96b0f..1711c344792aee 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -1,10 +1,23 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
 
-// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} 
-// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
-// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float2 fn(double2 D) {
-  uint2 A, B;
+
+// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+float test_scalar(double D) {
+  uint A, B;
+  asuint(D, A, B);
+  return A + B;
+}
+
+// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+float3 test_vector(double3 D) {
+  uint3 A, B;
   asuint(D, A, B);
   return A + B;
 }
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
index 8c56fdddb1c24c..b9a920f9f1b4d0 100644
--- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -6,6 +6,10 @@ uint4 test_asuint_too_many_arg(float p0, float p1) {
  // expected-error@-1 {{no matching function for call to 'asuint'}}
  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'V', but 2 arguments were provided}}
  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'F', but 2 arguments were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
 }
 
 uint test_asuint_double(double p1) {
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 197a50ac585798..996109ed352c7e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -93,7 +93,7 @@ def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
-def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+def int_dx_splitdouble : DefaultAttrsIntrinsic<
     [llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
     [IntrNoMem, IntrWillReturn]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 019f0b36cf4ff0..338cc546348b8d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,6 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
 def HandleTy : DXILOpParamType;
 def ResBindTy : DXILOpParamType;
 def ResPropsTy : DXILOpParamType;
+def ResSplitDoubleTy : DXILOpParamType;
 
 class DXILOpClass;
 
@@ -778,7 +779,15 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
-//
+
+def SplitDouble :  DXILOp<102, splitDouble> {
+  let Doc = "Splits a double into 2 uints";
+  let arguments = [OverloadTy];
+  let result = ResSplitDoubleTy;
+  let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
 
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index fc2755dfb24252..fb5383b3514a5a 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,7 +12,6 @@
 
 #include "DXILIntrinsicExpansion.h"
 #include "DirectX.h"
-#include "llvm-c/Core.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/DXILResource.h"
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 7719d6b1079110..982d7849d9bb8b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,6 +229,13 @@ static StructType *getResPropsType(LLVMContext &Context) {
   return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
 }
 
+static StructType *getResSplitDoubleType(LLVMContext &Context) {
+  if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
+    return ST;
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
+}
+
 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
                                     Type *OverloadTy) {
   switch (Kind) {
@@ -266,6 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
     return getResBindType(Ctx);
   case OpParamType::ResPropsTy:
     return getResPropsType(Ctx);
+  case OpParamType::ResSplitDoubleTy:
+    return getResSplitDoubleType(Ctx);
   }
   llvm_unreachable("Invalid parameter kind");
   return nullptr;
@@ -467,6 +476,10 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
   return ::getResRetType(ElementTy);
 }
 
+StructType *DXILOpBuilder::getResSplitDoubleType(LLVMContext &Context) {
+  return ::getResSplitDoubleType(Context);
+}
+
 StructType *DXILOpBuilder::getHandleType() {
   return ::getHandleType(IRB.getContext());
 }
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 037ae3822cfb90..8b1e87c283146c 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -49,6 +49,10 @@ class DXILOpBuilder {
 
   /// Get a `%dx.types.ResRet` type with the given element type.
   StructType *getResRetType(Type *ElementTy);
+
+  /// Get the `%dx.types.splitdouble` type.
+  StructType *getResSplitDoubleType(LLVMContext &Context);
+
   /// Get the `%dx.types.Handle` type.
   StructType *getHandleType();
 
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index c62ba8c21d6791..28e1cb2ce2e19f 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -17,6 +17,7 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/Module.h"
@@ -263,6 +264,31 @@ class OpLowerer {
     return lowerToBindAndAnnotateHandle(F);
   }
 
+  Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+
+    for (Use &U : Intrin->uses()) {
+      if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
+
+        assert(EVI->getNumIndices() == 1 &&
+               "splitdouble result should be indexed individually.");
+        if (EVI->getNumIndices() != 1)
+          return make_error<StringError>(
+              "splitdouble result should be indexed individually.",
+              inconvertibleErrorCode());
+
+        unsigned int IndexVal = EVI->getIndices()[0];
+
+        auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+        EVI->replaceAllUsesWith(OpEVI);
+        EVI->eraseFromParent();
+      }
+    }
+    Intrin->eraseFromParent();
+
+    return Error::success();
+  }
+
   /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
   /// Since we expect to be post-scalarization, make an effort to avoid vectors.
   Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
@@ -460,6 +486,27 @@ class OpLowerer {
     });
   }
 
+  [[nodiscard]] bool lowerSplitDouble(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+
+      Value *Arg0 = CI->getArgOperand(0);
+
+      Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
+
+      std::array<Value *, 1> Args{Arg0};
+      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+          OpCode::SplitDouble, Args, CI->getName(), NewRetTy);
+      if (Error E = OpCall.takeError())
+        return E;
+      if (Error E = replaceSplitDoubleCallUsages(CI, *OpCall))
+        return E;
+
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
     bool HasErrors = false;
@@ -488,6 +535,9 @@ class OpLowerer {
       case Intrinsic::dx_typedBufferStore:
         HasErrors |= lowerTypedBufferStore(F);
         break;
+      case Intrinsic::dx_splitdouble:
+        HasErrors |= lowerSplitDouble(F);
+        break;
       }
       Updated = true;
     }

>From fd6ea3d34e2930e2d642f241c2e85a74270e1dc6 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 00:20:54 +0000
Subject: [PATCH 04/18] adding tests

---
 clang/lib/CodeGen/CGExpr.cpp                  |  3 -
 ...uint-splitdouble.hlsl => splitdouble.hlsl} |  0
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |  2 +-
 llvm/test/CodeGen/DirectX/splitdouble.ll      | 63 +++++++++++++++++++
 4 files changed, 64 insertions(+), 4 deletions(-)
 rename clang/test/CodeGenHLSL/builtins/{asuint-splitdouble.hlsl => splitdouble.hlsl} (100%)
 create mode 100644 llvm/test/CodeGen/DirectX/splitdouble.ll

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 83c7f6537d22fb..35e6aa8cec6e7e 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,7 +19,6 @@
 #include "CGObjCRuntime.h"
 #include "CGOpenMPRuntime.h"
 #include "CGRecordLayout.h"
-#include "CGValue.h"
 #include "CodeGenFunction.h"
 #include "CodeGenModule.h"
 #include "ConstantEmitter.h"
@@ -29,7 +28,6 @@
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/StmtVisitor.h"
-#include "clang/AST/Type.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/CodeGenOptions.h"
 #include "clang/Basic/SourceManager.h"
@@ -54,7 +52,6 @@
 
 #include <optional>
 #include <string>
-#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
similarity index 100%
rename from clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
rename to clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 28e1cb2ce2e19f..2b62575ea77821 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -267,7 +267,7 @@ class OpLowerer {
   Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
     IRBuilder<> &IRB = OpBuilder.getIRB();
 
-    for (Use &U : Intrin->uses()) {
+    for (Use &U : make_early_inc_range(Intrin->uses())) {
       if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
 
         assert(EVI->getNumIndices() == 1 &&
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
new file mode 100644
index 00000000000000..3ada8c07325431
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -0,0 +1,63 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; ModuleID = '../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl'
+source_filename = "../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl"
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxilv1.3-pc-shadermodel6.3-library"
+
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define noundef float @"?test_scalar@@YAMN@Z"(double noundef %D) local_unnamed_addr #0 {
+entry:
+  ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
+  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+  %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
+  %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+  %add = add i32 %0, %1
+  %conv = uitofp i32 %add to float
+  ret float %conv
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
+
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define noundef <3 x float> @"?test_vector@@YAT?$__vector@M$02@__clang@@T?$__vector@N$02@2@@Z"(<3 x double> noundef %D) local_unnamed_addr #0 {
+entry:
+  %0 = extractelement <3 x double> %D, i64 0
+  ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
+  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
+  %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
+  %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
+  %3 = insertelement <3 x i32> poison, i32 %1, i64 0
+  %4 = insertelement <3 x i32> poison, i32 %2, i64 0
+  %5 = extractelement <3 x double> %D, i64 1
+  %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
+  %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
+  %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
+  %8 = insertelement <3 x i32> %3, i32 %6, i64 1
+  %9 = insertelement <3 x i32> %4, i32 %7, i64 1
+  %10 = extractelement <3 x double> %D, i64 2
+  %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
+  %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
+  %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
+  %13 = insertelement <3 x i32> %8, i32 %11, i64 2
+  %14 = insertelement <3 x i32> %9, i32 %12, i64 2
+  %add = add <3 x i32> %13, %14
+  %conv = uitofp <3 x i32> %add to <3 x float>
+  ret <3 x float> %conv
+}
+
+attributes #0 = { alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none) "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0}
+!dx.valver = !{!1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 8}
+!2 = !{!"clang version 20.0.0git (https://github.com/joaosaffran/llvm-project.git 81476c7ad27010600dc4b4be1d66e7c7db7c10fb)"}

>From b4a5ab557c2c98f99f0ba3f7250c645167de4526 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 05:32:19 +0000
Subject: [PATCH 05/18] adding SPIRV

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 29 +++++++++++++++++++
 .../CodeGenHLSL/builtins/splitdouble.hlsl     |  9 ++++++
 2 files changed, 38 insertions(+)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ef59ac777796f5..0ad944b9e5d181 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -456,6 +456,35 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double3, out uint3, out uint3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double4, out uint4, out uint4);
+
+#elif __is_target_arch(spirv)
+
+void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
+  uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+  uint4 top = __detail::bit_cast<uint4>(D.zw);
+  lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
+  highbits = uint4(bottom.y, bottom.w, top.y, top.w);
+}
+
+void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
+  uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+  uint2 top = __detail::bit_cast<uint2>(D.z);
+  lowbits = uint3(bottom.x, bottom.z, top.x);
+  highbits = uint3(bottom.y, bottom.w, top.y);
+}
+
+void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
+  uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+  lowbits = uint2(bottom.x, bottom.z);
+  highbits = uint2(bottom.y, bottom.w);
+}
+
+void asuint(double D, out uint lowbits, out uint highbits) {
+  uint2 bottom = __detail::bit_cast<uint2>(D);
+  lowbits = uint(bottom.x);
+  highbits = uint(bottom.y);
+}
+
 #endif
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 1711c344792aee..8febc500d3c2b9 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,10 +1,15 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
 
 
 // CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
 // CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble
+// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
+// SPIRV: call spir_func void {{.*}}asuint{{.*}}(double {{.*}} [[REG]], {{.*}})
 float test_scalar(double D) {
   uint A, B;
   asuint(D, A, B);
@@ -16,6 +21,10 @@ float test_scalar(double D) {
 // CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble
+// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
+// SPIRV: call spir_func void {{.*}}asuint{{.*}}(<3 x double> {{.*}} [[REG]], {{.*}})
 float3 test_vector(double3 D) {
   uint3 A, B;
   asuint(D, A, B);

>From 549cf717ba4bbbbf13d212f2c9ad165c8d9eb9df Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 17:31:19 +0000
Subject: [PATCH 06/18] fixing hlsl-lang-targets-spirv.hlsl test

---
 clang/test/Driver/hlsl-lang-targets-spirv.hlsl | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
index 61b10e1648c52b..5928c948315f1e 100644
--- a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
+++ b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
@@ -3,12 +3,12 @@
 
 // Supported targets
 //
-// RUN: %clang -target dxil-unknown-shadermodel6.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv-unknown-vulkan-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv1.5-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv1.6-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple dxil-unknown-shadermodel6.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv-unknown-vulkan-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv1.5-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv1.6-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
 
 // Empty Vulkan environment
 //

>From 7529ee6d255f6c4ad0ea952aace30d350240a1f1 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 27 Sep 2024 21:06:27 +0000
Subject: [PATCH 07/18] fixing comments in test

---
 llvm/test/CodeGen/DirectX/splitdouble.ll | 27 ++++++------------------
 1 file changed, 7 insertions(+), 20 deletions(-)

diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index 3ada8c07325431..bfd337042851bd 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,12 +1,11 @@
 ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
-; ModuleID = '../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl'
-source_filename = "../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl"
-target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
-target triple = "dxilv1.3-pc-shadermodel6.3-library"
+; Make sure DXILOpLowering is correctly generating the dxil op code call.
 
-; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
-define noundef float @"?test_scalar@@YAMN@Z"(double noundef %D) local_unnamed_addr #0 {
+
+
+; CHECK-LABEL: define noundef float @test_scalar_double_split
+define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
 entry:
   ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
@@ -19,11 +18,10 @@ entry:
   ret float %conv
 }
 
-; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
 declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
 
-; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
-define noundef <3 x float> @"?test_vector@@YAT?$__vector@M$02@__clang@@T?$__vector@N$02@2@@Z"(<3 x double> noundef %D) local_unnamed_addr #0 {
+; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
+define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
 entry:
   %0 = extractelement <3 x double> %D, i64 0
   ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
@@ -50,14 +48,3 @@ entry:
   %conv = uitofp <3 x i32> %add to <3 x float>
   ret <3 x float> %conv
 }
-
-attributes #0 = { alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none) "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
-attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
-
-!llvm.module.flags = !{!0}
-!dx.valver = !{!1}
-!llvm.ident = !{!2}
-
-!0 = !{i32 1, !"wchar_size", i32 4}
-!1 = !{i32 1, i32 8}
-!2 = !{!"clang version 20.0.0git (https://github.com/joaosaffran/llvm-project.git 81476c7ad27010600dc4b4be1d66e7c7db7c10fb)"}

>From d09c11d7e0d130a064e7df7d014b45c94e85bd52 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 27 Sep 2024 23:25:47 +0000
Subject: [PATCH 08/18] changing intrinsic signature to return vector

---
 clang/lib/CodeGen/CGBuiltin.cpp               |  8 +++--
 .../CodeGenHLSL/builtins/splitdouble.hlsl     | 18 +++++-----
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  4 +--
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    | 19 ++++-------
 llvm/test/CodeGen/DirectX/splitdouble.ll      | 34 +++++++++----------
 5 files changed, 41 insertions(+), 42 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7e442c361d2c32..87912fa3caa973 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -39,6 +39,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -18978,8 +18979,8 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
           Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
                                    {arg}, nullptr, "hlsl.asuint");
 
-      Value *arg0 = Builder->CreateExtractValue(CI, 0);
-      Value *arg1 = Builder->CreateExtractValue(CI, 1);
+      Value *arg0 = Builder->CreateExtractElement(CI, (uint64_t)0);
+      Value *arg1 = Builder->CreateExtractElement(CI, (uint64_t)1);
 
       return std::make_pair(arg0, arg1);
     };
@@ -18990,7 +18991,8 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     auto [Op2BaseLValue, Op2TmpLValue] =
         EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
-    llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+    llvm::VectorType *retType =
+        llvm::VectorType::get(Int32Ty, ElementCount::getFixed(2));
 
     if (!Op0->getType()->isVectorTy()) {
       auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 8febc500d3c2b9..f9e2122f4587a1 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,11 +1,12 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
 
 
 // CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// CHECK: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
+// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[REG]])
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
 // SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble
 // SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
@@ -17,10 +18,11 @@ float test_scalar(double D) {
 }
 
 // CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
-// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// CHECK: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[REG]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[VALREG]])
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
 // SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble
 // SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 996109ed352c7e..849959a972f3d9 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -94,7 +94,7 @@ def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
 def int_dx_splitdouble : DefaultAttrsIntrinsic<
-    [llvm_anyint_ty, LLVMMatchType<0>], 
-    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
+    [llvm_anyvector_ty], 
+    [llvm_double_ty], 
     [IntrNoMem, IntrWillReturn]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 2b62575ea77821..0e3c35fe6a6147 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -268,20 +268,15 @@ class OpLowerer {
     IRBuilder<> &IRB = OpBuilder.getIRB();
 
     for (Use &U : make_early_inc_range(Intrin->uses())) {
-      if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
-
-        assert(EVI->getNumIndices() == 1 &&
-               "splitdouble result should be indexed individually.");
-        if (EVI->getNumIndices() != 1)
-          return make_error<StringError>(
-              "splitdouble result should be indexed individually.",
-              inconvertibleErrorCode());
+      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
+        if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
 
-        unsigned int IndexVal = EVI->getIndices()[0];
+          size_t IndexVal = IndexOp->getZExtValue();
 
-        auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
-        EVI->replaceAllUsesWith(OpEVI);
-        EVI->eraseFromParent();
+          auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+          EEI->replaceAllUsesWith(OpEVI);
+          EEI->eraseFromParent();
+        }
       }
     }
     Intrin->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index bfd337042851bd..e1da2b2d4a9d66 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,8 +1,7 @@
 ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
 
-; Make sure DXILOpLowering is correctly generating the dxil op code call.
-
-
+; Make sure DXILOpLowering is correctly generating the dxil op code call, with and without scalarizer.
 
 ; CHECK-LABEL: define noundef float @test_scalar_double_split
 define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
@@ -10,15 +9,16 @@ entry:
   ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
-  %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
-  %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
-  %add = add i32 %0, %1
+  %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %D)
+  %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
+  %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
+  %add = add i32 %1, %2
   %conv = uitofp i32 %add to float
   ret float %conv
 }
 
-declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
+declare <2 x i32> @llvm.dx.splitdouble.v2i32(double) #1
+
 
 ; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
 define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
@@ -27,21 +27,21 @@ entry:
   ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
-  %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
-  %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
+  %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %0)
+  %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
+  %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
   %3 = insertelement <3 x i32> poison, i32 %1, i64 0
   %4 = insertelement <3 x i32> poison, i32 %2, i64 0
   %5 = extractelement <3 x double> %D, i64 1
-  %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
-  %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
-  %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
+  %hlsl.asuint2 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %5)
+  %6 = extractelement <2 x i32> %hlsl.asuint2, i64 0
+  %7 = extractelement <2 x i32> %hlsl.asuint2, i64 1
   %8 = insertelement <3 x i32> %3, i32 %6, i64 1
   %9 = insertelement <3 x i32> %4, i32 %7, i64 1
   %10 = extractelement <3 x double> %D, i64 2
-  %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
-  %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
-  %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
+  %hlsl.asuint3 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %10)
+  %11 = extractelement <2 x i32> %hlsl.asuint3, i64 0
+  %12 = extractelement <2 x i32> %hlsl.asuint3, i64 1
   %13 = insertelement <3 x i32> %8, i32 %11, i64 2
   %14 = insertelement <3 x i32> %9, i32 %12, i64 2
   %add = add <3 x i32> %13, %14

>From fb3f08fdbd25c24f10b299f66a28ff00f86c2a57 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Sat, 28 Sep 2024 06:19:42 +0000
Subject: [PATCH 09/18] pushing original changes

---
 clang/lib/CodeGen/CGBuiltin.cpp               |  7 +++--
 .../CodeGenHLSL/builtins/splitdouble.hlsl     | 20 +++++++-------
 .../test/Driver/hlsl-lang-targets-spirv.hlsl  | 12 ++++-----
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    | 19 +++++++++-----
 llvm/test/CodeGen/DirectX/splitdouble.ll      | 26 +++++++++----------
 6 files changed, 45 insertions(+), 41 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 87912fa3caa973..ddb0418265a7ff 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18979,8 +18979,8 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
           Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
                                    {arg}, nullptr, "hlsl.asuint");
 
-      Value *arg0 = Builder->CreateExtractElement(CI, (uint64_t)0);
-      Value *arg1 = Builder->CreateExtractElement(CI, (uint64_t)1);
+      Value *arg0 = Builder->CreateExtractValue(CI, 0);
+      Value *arg1 = Builder->CreateExtractValue(CI, 1);
 
       return std::make_pair(arg0, arg1);
     };
@@ -18991,8 +18991,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     auto [Op2BaseLValue, Op2TmpLValue] =
         EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
-    llvm::VectorType *retType =
-        llvm::VectorType::get(Int32Ty, ElementCount::getFixed(2));
+    llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
 
     if (!Op0->getType()->isVectorTy()) {
       auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index f9e2122f4587a1..b937bb5d4d343d 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,12 +1,12 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
 
 
+
 // CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
-// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[REG]])
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
+// CHECK: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
 // SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble
 // SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
@@ -17,12 +17,12 @@ float test_scalar(double D) {
   return A + B;
 }
 
+
 // CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align
-// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[REG]], i64 [[VALIDX:[0-3]]]
-// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[VALREG]])
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
 // SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble
 // SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
diff --git a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
index 5928c948315f1e..61b10e1648c52b 100644
--- a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
+++ b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
@@ -3,12 +3,12 @@
 
 // Supported targets
 //
-// RUN: %clang -cc1 -triple dxil-unknown-shadermodel6.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv-unknown-vulkan-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv1.5-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv1.6-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target dxil-unknown-shadermodel6.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv-unknown-vulkan-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv1.5-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv1.6-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
 
 // Empty Vulkan environment
 //
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 849959a972f3d9..e69bc0d1f502a1 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -94,7 +94,7 @@ def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
 def int_dx_splitdouble : DefaultAttrsIntrinsic<
-    [llvm_anyvector_ty], 
+    [llvm_anyint_ty, LLVMMatchType<0>], 
     [llvm_double_ty], 
     [IntrNoMem, IntrWillReturn]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 0e3c35fe6a6147..340914ff2cf422 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -23,7 +23,9 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Object/Error.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "dxil-op-lower"
@@ -268,17 +270,20 @@ class OpLowerer {
     IRBuilder<> &IRB = OpBuilder.getIRB();
 
     for (Use &U : make_early_inc_range(Intrin->uses())) {
-      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
-        if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
+      if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
 
-          size_t IndexVal = IndexOp->getZExtValue();
+        if (EVI->getNumIndices() != 1)
+          return createStringError(std::errc::invalid_argument,
+                                   "Splitdouble has only 2 elements");
 
-          auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
-          EEI->replaceAllUsesWith(OpEVI);
-          EEI->eraseFromParent();
-        }
+        size_t IndexVal = EVI->getIndices()[0];
+
+        auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+        EVI->replaceAllUsesWith(OpEVI);
+        EVI->eraseFromParent();
       }
     }
+
     Intrin->eraseFromParent();
 
     return Error::success();
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index e1da2b2d4a9d66..c62b7dd2371ba2 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -9,10 +9,10 @@ entry:
   ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %D)
-  %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
-  %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
-  %add = add i32 %1, %2
+  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+  %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
+  %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+  %add = add i32 %0, %1
   %conv = uitofp i32 %add to float
   ret float %conv
 }
@@ -27,21 +27,21 @@ entry:
   ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %0)
-  %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
-  %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
+  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
+  %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
+  %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
   %3 = insertelement <3 x i32> poison, i32 %1, i64 0
   %4 = insertelement <3 x i32> poison, i32 %2, i64 0
   %5 = extractelement <3 x double> %D, i64 1
-  %hlsl.asuint2 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %5)
-  %6 = extractelement <2 x i32> %hlsl.asuint2, i64 0
-  %7 = extractelement <2 x i32> %hlsl.asuint2, i64 1
+  %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
+  %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
+  %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
   %8 = insertelement <3 x i32> %3, i32 %6, i64 1
   %9 = insertelement <3 x i32> %4, i32 %7, i64 1
   %10 = extractelement <3 x double> %D, i64 2
-  %hlsl.asuint3 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %10)
-  %11 = extractelement <2 x i32> %hlsl.asuint3, i64 0
-  %12 = extractelement <2 x i32> %hlsl.asuint3, i64 1
+  %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
+  %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
+  %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
   %13 = insertelement <3 x i32> %8, i32 %11, i64 2
   %14 = insertelement <3 x i32> %9, i32 %12, i64 2
   %add = add <3 x i32> %13, %14

>From b8660ff9574acbb1a3ce77c3fc82a5528827b398 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Sat, 28 Sep 2024 18:32:45 +0000
Subject: [PATCH 10/18] adding static inline attributes

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 0ad944b9e5d181..80a63b33ea99e6 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -459,27 +459,27 @@ void asuint(double4, out uint4, out uint4);
 
 #elif __is_target_arch(spirv)
 
-void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
+static inline void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
   uint4 bottom = __detail::bit_cast<uint4>(D.xy);
   uint4 top = __detail::bit_cast<uint4>(D.zw);
   lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
   highbits = uint4(bottom.y, bottom.w, top.y, top.w);
 }
 
-void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
+static inline void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
   uint4 bottom = __detail::bit_cast<uint4>(D.xy);
   uint2 top = __detail::bit_cast<uint2>(D.z);
   lowbits = uint3(bottom.x, bottom.z, top.x);
   highbits = uint3(bottom.y, bottom.w, top.y);
 }
 
-void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
+static inline void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
   uint4 bottom = __detail::bit_cast<uint4>(D.xy);
   lowbits = uint2(bottom.x, bottom.z);
   highbits = uint2(bottom.y, bottom.w);
 }
 
-void asuint(double D, out uint lowbits, out uint highbits) {
+static inline void asuint(double D, out uint lowbits, out uint highbits) {
   uint2 bottom = __detail::bit_cast<uint2>(D);
   lowbits = uint(bottom.x);
   highbits = uint(bottom.y);

>From b3f6965565c8fd15a205cbaa23963d114419c7db Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 1 Oct 2024 19:40:24 +0000
Subject: [PATCH 11/18] refactoring spirv

---
 clang/lib/CodeGen/CGBuiltin.cpp               |  8 ++---
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 30 -------------------
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  2 ++
 4 files changed, 7 insertions(+), 34 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index ddb0418265a7ff..345f1a6f8e498a 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18973,10 +18973,10 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
 
     auto emitSplitDouble =
-        [](CGBuilderTy *Builder, llvm::Value *arg,
+        [](CGBuilderTy *Builder, llvm::Intrinsic::ID intrId, llvm::Value *arg,
            llvm::Type *retType) -> std::pair<Value *, Value *> {
       CallInst *CI =
-          Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
+          Builder->CreateIntrinsic(retType, intrId,
                                    {arg}, nullptr, "hlsl.asuint");
 
       Value *arg0 = Builder->CreateExtractValue(CI, 0);
@@ -18994,7 +18994,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
 
     if (!Op0->getType()->isVectorTy()) {
-      auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
+      auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), Op0, retType);
 
       Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
       auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
@@ -19013,7 +19013,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
       Value *op = Builder.CreateExtractElement(Op0, idx);
 
-      auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType);
+      auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), op, retType);
 
       if (idx == 0) {
         inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..9fa1a495c5d602 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -88,6 +88,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Splitdouble, splitdouble);
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 80a63b33ea99e6..a669a673f65a69 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -447,7 +447,6 @@ template <typename T> constexpr uint asuint(T F) {
 /// \param D The input double.
 /// \param lowbits The output lowbits of D.
 /// \param highbits The highbits lowbits D.
-#if __is_target_arch(dxil)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double, out uint, out uint);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
@@ -457,35 +456,6 @@ void asuint(double3, out uint3, out uint3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double4, out uint4, out uint4);
 
-#elif __is_target_arch(spirv)
-
-static inline void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
-  uint4 bottom = __detail::bit_cast<uint4>(D.xy);
-  uint4 top = __detail::bit_cast<uint4>(D.zw);
-  lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
-  highbits = uint4(bottom.y, bottom.w, top.y, top.w);
-}
-
-static inline void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
-  uint4 bottom = __detail::bit_cast<uint4>(D.xy);
-  uint2 top = __detail::bit_cast<uint2>(D.z);
-  lowbits = uint3(bottom.x, bottom.z, top.x);
-  highbits = uint3(bottom.y, bottom.w, top.y);
-}
-
-static inline void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
-  uint4 bottom = __detail::bit_cast<uint4>(D.xy);
-  lowbits = uint2(bottom.x, bottom.z);
-  highbits = uint2(bottom.y, bottom.w);
-}
-
-static inline void asuint(double D, out uint lowbits, out uint highbits) {
-  uint2 bottom = __detail::bit_cast<uint2>(D);
-  lowbits = uint(bottom.x);
-  highbits = uint(bottom.y);
-}
-
-#endif
 
 //===----------------------------------------------------------------------===//
 // atan builtins
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..b586ad1f4d50ce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2569,6 +2569,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   }
   case Intrinsic::spv_wave_readlane:
     return selectWaveReadLaneAt(ResVReg, ResType, I);
+  case Intrinsic::spv_splitdouble:
+    return selectSplitdouble(ResVReg, ResType, I);
   case Intrinsic::spv_step:
     return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
   case Intrinsic::spv_radians:

>From 4263f0ad8f2d6af37b7c3347d38c906f6c039c88 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Wed, 2 Oct 2024 18:57:58 +0000
Subject: [PATCH 12/18] adding dxil codegen

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 60 ++++++-------------
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 -
 .../CodeGenHLSL/builtins/splitdouble.hlsl     | 26 +++-----
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |  6 ++
 llvm/test/CodeGen/DirectX/splitdouble.ll      | 51 +++-------------
 .../test/CodeGen/DirectX/splitdouble_error.ll | 16 +++++
 7 files changed, 57 insertions(+), 105 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/splitdouble_error.ll

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 345f1a6f8e498a..03afeedf553e25 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18967,67 +18967,41 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
             E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
             E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
            "asuint operands types mismatch");
-
     Value *Op0 = EmitScalarExpr(E->getArg(0));
     const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
     const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
 
-    auto emitSplitDouble =
-        [](CGBuilderTy *Builder, llvm::Intrinsic::ID intrId, llvm::Value *arg,
-           llvm::Type *retType) -> std::pair<Value *, Value *> {
-      CallInst *CI =
-          Builder->CreateIntrinsic(retType, intrId,
-                                   {arg}, nullptr, "hlsl.asuint");
-
-      Value *arg0 = Builder->CreateExtractValue(CI, 0);
-      Value *arg1 = Builder->CreateExtractValue(CI, 1);
-
-      return std::make_pair(arg0, arg1);
-    };
-
     CallArgList Args;
     auto [Op1BaseLValue, Op1TmpLValue] =
         EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
     auto [Op2BaseLValue, Op2TmpLValue] =
         EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
-    llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+    if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) {
 
-    if (!Op0->getType()->isVectorTy()) {
-      auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), Op0, retType);
-
-      Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
-      auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
-      EmitWritebacks(*this, Args);
-      return s;
-    }
+      llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
 
-    auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+      if (Op0->getType()->isVectorTy()) {
+        auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
 
-    llvm::VectorType *i32VecTy = llvm::VectorType::get(
-        Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+        llvm::VectorType *i32VecTy = llvm::VectorType::get(
+            Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+        retType = llvm::StructType::get(i32VecTy, i32VecTy);
+      }
 
-    std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
+      CallInst *CI =
+          Builder.CreateIntrinsic(retType, Intrinsic::dx_splitdouble, {Op0},
+                                  nullptr, "hlsl.splitdouble");
 
-    for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
-      Value *op = Builder.CreateExtractElement(Op0, idx);
+      Value *arg0 = Builder.CreateExtractValue(CI, 0);
+      Value *arg1 = Builder.CreateExtractValue(CI, 1);
 
-      auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), op, retType);
+      Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+      auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
 
-      if (idx == 0) {
-        inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
-        inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
-      } else {
-        inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
-        inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
-      }
+      EmitWritebacks(*this, Args);
+      return s;
     }
-
-    Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
-    auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
-    EmitWritebacks(*this, Args);
-    return s;
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 9fa1a495c5d602..ff7df41b5c62e7 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -88,7 +88,6 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
-  GENERATE_HLSL_INTRINSIC_FUNCTION(Splitdouble, splitdouble);
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index b937bb5d4d343d..4f3a2330af924e 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,33 +1,23 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
 
 
 
-// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK: define {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
-// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble
-// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
-// SPIRV: call spir_func void {{.*}}asuint{{.*}}(double {{.*}} [[REG]], {{.*}})
-float test_scalar(double D) {
+uint test_scalar(double D) {
   uint A, B;
   asuint(D, A, B);
   return A + B;
 }
 
 
-// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
-// CHECK-NEXT: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
-// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble
-// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
-// SPIRV: call spir_func void {{.*}}asuint{{.*}}(<3 x double> {{.*}} [[REG]], {{.*}})
-float3 test_vector(double3 D) {
+// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
+uint3 test_vector(double3 D) {
   uint3 A, B;
   asuint(D, A, B);
   return A + B;
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e69bc0d1f502a1..996109ed352c7e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -95,6 +95,6 @@ def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>
 
 def int_dx_splitdouble : DefaultAttrsIntrinsic<
     [llvm_anyint_ty, LLVMMatchType<0>], 
-    [llvm_double_ty], 
+    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
     [IntrNoMem, IntrWillReturn]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 340914ff2cf422..344f7bb517c2bc 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -493,6 +493,12 @@ class OpLowerer {
 
       Value *Arg0 = CI->getArgOperand(0);
 
+      if (Arg0->getType()->isVectorTy()) {
+        return make_error<StringError>(
+            "splitdouble doesn't support lowering vector types.",
+            inconvertibleErrorCode());
+      }
+
       Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
 
       std::array<Value *, 1> Args{Arg0};
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index c62b7dd2371ba2..6da3b5797b4cba 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,50 +1,17 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
 
-; Make sure DXILOpLowering is correctly generating the dxil op code call, with and without scalarizer.
+; Make sure DXILOpLowering is correctly generating the dxil op, with and without scalarizer.
 
-; CHECK-LABEL: define noundef float @test_scalar_double_split
-define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
+; CHECK-LABEL: define noundef i32 @test_scalar_double_split
+define noundef i32 @test_scalar_double_split(double noundef %D) local_unnamed_addr {
 entry:
   ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
   ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
-  %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
-  %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+  %hlsl.splitdouble = call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+  %0 = extractvalue { i32, i32 } %hlsl.splitdouble, 0
+  %1 = extractvalue { i32, i32 } %hlsl.splitdouble, 1
   %add = add i32 %0, %1
-  %conv = uitofp i32 %add to float
-  ret float %conv
-}
-
-declare <2 x i32> @llvm.dx.splitdouble.v2i32(double) #1
-
-
-; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
-define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
-entry:
-  %0 = extractelement <3 x double> %D, i64 0
-  ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
-  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
-  %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
-  %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
-  %3 = insertelement <3 x i32> poison, i32 %1, i64 0
-  %4 = insertelement <3 x i32> poison, i32 %2, i64 0
-  %5 = extractelement <3 x double> %D, i64 1
-  %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
-  %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
-  %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
-  %8 = insertelement <3 x i32> %3, i32 %6, i64 1
-  %9 = insertelement <3 x i32> %4, i32 %7, i64 1
-  %10 = extractelement <3 x double> %D, i64 2
-  %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
-  %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
-  %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
-  %13 = insertelement <3 x i32> %8, i32 %11, i64 2
-  %14 = insertelement <3 x i32> %9, i32 %12, i64 2
-  %add = add <3 x i32> %13, %14
-  %conv = uitofp <3 x i32> %add to <3 x float>
-  ret <3 x float> %conv
+  ret i32 %add
 }
diff --git a/llvm/test/CodeGen/DirectX/splitdouble_error.ll b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
new file mode 100644
index 00000000000000..acfd52b24c9cc3
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
@@ -0,0 +1,16 @@
+; RUN: not opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation splitdouble doesn't support vector types.
+; CHECK: in function test_vector_double_split
+; CHECK-SAME: splitdouble doesn't support lowering vector types.
+
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
+entry:
+  %hlsl.splitdouble = tail call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %D)
+  %0 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 0
+  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 1
+  %add = add <3 x i32> %0, %1
+  ret <3 x i32> %add
+}
+
+declare { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double>)

>From 9d25a057d50144dd36a3cb421774ebd58e6e84b7 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Wed, 2 Oct 2024 22:07:42 +0000
Subject: [PATCH 13/18] remove spirv lowering

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b586ad1f4d50ce..d9377fe4b91a1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2569,8 +2569,6 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   }
   case Intrinsic::spv_wave_readlane:
     return selectWaveReadLaneAt(ResVReg, ResType, I);
-  case Intrinsic::spv_splitdouble:
-    return selectSplitdouble(ResVReg, ResType, I);
   case Intrinsic::spv_step:
     return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
   case Intrinsic::spv_radians:

>From 3175e4cde9ab73da079439bc5aecd724e51dba8c Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 3 Oct 2024 01:24:08 +0000
Subject: [PATCH 14/18] adding spirv codegen

---
 clang/lib/CodeGen/CGBuiltin.cpp | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 03afeedf553e25..42537dea54de1b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19001,7 +19001,29 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
 
       EmitWritebacks(*this, Args);
       return s;
+    } 
+
+
+    if(!Op0->getType()->isVectorTy()){
+        FixedVectorType *destTy = FixedVectorType::get(Int32Ty, 2);
+        Value *bitcast = Builder.CreateBitCast(Op0, destTy);
+
+        Value *arg0 = Builder.CreateExtractElement(bitcast, 0.0);
+        Value *arg1 = Builder.CreateExtractElement(bitcast, 1.0);
+
+        Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+        auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+
+        EmitWritebacks(*this, Args);
+        return s;
+    }
+    
+    auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+    for(int idx = 0 ; idx < Op0VecTy -> getNumElements(); idx += 2){
+      
     }
+
   }
   }
   return nullptr;

>From 7338e79d658de83794a83d370eace6438af20f25 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 4 Oct 2024 01:52:16 +0000
Subject: [PATCH 15/18] adding spirv codegen

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 26 ++++++++++++++++++-
 .../CodeGenHLSL/builtins/splitdouble.hlsl     | 11 ++++++++
 .../SPIRV/hlsl-intrinsics/splitdouble.ll      | 16 ++++++++++++
 .../hlsl-intrinsics/splitdouble_vector.ll     | 14 ++++++++++
 4 files changed, 66 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 42537dea54de1b..396f76705241e8 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,8 +34,10 @@
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
@@ -19020,10 +19022,32 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     
     auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
 
-    for(int idx = 0 ; idx < Op0VecTy -> getNumElements(); idx += 2){
+    int numElements = Op0VecTy -> getNumElements() * 2;
+
+    FixedVectorType *destTy = FixedVectorType::get(Int32Ty, numElements);
       
+    Value *bitcast = Builder.CreateBitCast(Op0, destTy);
+
+    SmallVector<int> lowbitsIndex;
+    SmallVector<int> highbitsIndex;
+
+    for(int idx = 0; idx < numElements; idx += 2){
+      lowbitsIndex.push_back(idx);
+    }
+
+    for(int idx = 1; idx < numElements; idx += 2){
+      highbitsIndex.push_back(idx);
     }
 
+    Value *arg0 = Builder.CreateShuffleVector(bitcast, lowbitsIndex);
+    Value *arg1 = Builder.CreateShuffleVector(bitcast, highbitsIndex);
+
+    Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+    auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+
+    EmitWritebacks(*this, Args);
+    return s;
+
   }
   }
   return nullptr;
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 4f3a2330af924e..e1f42824cfe5ed 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,4 +1,5 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPIRV
 
 
 
@@ -6,6 +7,11 @@
 // CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[CAST:%.*]] = bitcast double [[VALD]] to <2 x i32>
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
 uint test_scalar(double D) {
   uint A, B;
   asuint(D, A, B);
@@ -17,6 +23,11 @@ uint test_scalar(double D) {
 // CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
 // CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[CAST:%.*]] = bitcast <3 x double> [[VALD]] to <6 x i32>
+// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
 uint3 test_vector(double3 D) {
   uint3 A, B;
   asuint(D, A, B);
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
new file mode 100644
index 00000000000000..c057042c0d142e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
@@ -0,0 +1,16 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure lowering is correctly generating spirv code.
+
+define spir_func noundef i32 @test_scalar(double noundef %D) local_unnamed_addr {
+entry:
+  ; CHECK: %[[#]] = OpBitcast %[[#]] %[[#]]
+  %0 = bitcast double %D to <2 x i32>
+  ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 0
+  %1 = extractelement <2 x i32> %0, i64 0
+  ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 1
+  %2 = extractelement <2 x i32> %0, i64 1
+  %add = add i32 %1, %2
+  ret i32 %add
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
new file mode 100644
index 00000000000000..58bd5a046ff3d1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
@@ -0,0 +1,14 @@
+; RUN: opt -S -scalarizer -mtriple=spirv-vulkan-library %s 2>&1 | llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -o - | FileCheck %s
+
+; SPIRV lowering for splitdouble should rely on the scalarizer.
+
+define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
+entry:
+  ; CHECK-COUNT-3: %[[#]] = OpBitcast %[[#]] %[[#]]
+  ; CHECK-COUNT-3: %[[#]] = OpCompositeExtract %[[#]] %[[#]] [[0-2]]
+  %0 = bitcast <3 x double> %D to <6 x i32>
+  %1 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+  %2 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
+  %add = add <3 x i32> %1, %2
+  ret <3 x i32> %add
+}

>From 8a32543ceb02e06fe0d5522b4ca1b65f855d6120 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 7 Oct 2024 22:15:53 +0000
Subject: [PATCH 16/18] addressing PR comments

---
 clang/include/clang/Basic/Builtins.td         |  14 +-
 clang/lib/CodeGen/CGBuiltin.cpp               | 209 +++++++++++-------
 clang/lib/CodeGen/CGCall.cpp                  |  15 +-
 clang/lib/CodeGen/CGExpr.cpp                  |   7 +-
 clang/lib/CodeGen/CodeGenFunction.h           |   6 +-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  11 +-
 clang/lib/Sema/SemaHLSL.cpp                   |  56 ++---
 .../CodeGenHLSL/builtins/splitdouble.hlsl     |  80 ++++++-
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   5 -
 llvm/lib/Target/DirectX/DXIL.td               |   4 +-
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     |  10 +-
 llvm/lib/Target/DirectX/DXILOpBuilder.h       |   2 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |  75 ++++---
 llvm/test/CodeGen/DirectX/split-double.ll     |  45 ----
 llvm/test/CodeGen/DirectX/splitdouble.ll      |  86 +++++--
 .../test/CodeGen/DirectX/splitdouble_error.ll |  16 --
 .../SPIRV/hlsl-intrinsics/splitdouble.ll      |  44 +++-
 .../hlsl-intrinsics/splitdouble_vector.ll     |  14 --
 18 files changed, 392 insertions(+), 307 deletions(-)
 delete mode 100644 llvm/test/CodeGen/DirectX/split-double.ll
 delete mode 100644 llvm/test/CodeGen/DirectX/splitdouble_error.ll
 delete mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b35a24eda6da44..9bd67e0cefebc3 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,20 +4871,8 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
-  let Spellings = ["__builtin_hlsl_elementwise_radians"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
-def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
-  let Spellings = ["__builtin_hlsl_elementwise_radians"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
 def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
-  let Spellings = ["__builtin_hlsl_splitdouble"];
+  let Spellings = ["__builtin_hlsl_elementwise_splitdouble"];
   let Attributes = [NoThrow, Const];
   let Prototype = "void(...)";
 }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 396f76705241e8..05a41d75c1c419 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17,6 +17,7 @@
 #include "CGObjCRuntime.h"
 #include "CGOpenCLRuntime.h"
 #include "CGRecordLayout.h"
+#include "CGValue.h"
 #include "CodeGenFunction.h"
 #include "CodeGenModule.h"
 #include "ConstantEmitter.h"
@@ -25,8 +26,10 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/Expr.h"
 #include "clang/AST/OSLog.h"
 #include "clang/AST/OperationKinds.h"
+#include "clang/AST/Type.h"
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Basic/TargetOptions.h"
@@ -34,14 +37,11 @@
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
@@ -70,6 +70,7 @@
 #include "llvm/TargetParser/X86TargetParser.h"
 #include <optional>
 #include <sstream>
+#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
@@ -98,6 +99,125 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
   I->addAnnotationMetadata("auto-init");
 }
 
+/// Emits the HLSL splitdouble builtin (`asuint(double, out uint, out uint)`):
+/// stores the low 32 bits of each double element to the first out argument
+/// and the high 32 bits to the second, then emits the out-arg writebacks.
+/// DXIL targets use the `dx.splitdouble` intrinsic; other targets expand it
+/// inline. Returns the last store emitted, or nullptr after reporting an
+/// error for an unsupported vector width.
+static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
+  Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
+  // Arguments 1 and 2 are `out` parameters, expected to be HLSLOutArgExprs.
+  // cast<> asserts this instead of yielding a null that is dereferenced.
+  const auto *OutArg1 = cast<HLSLOutArgExpr>(E->getArg(1));
+  const auto *OutArg2 = cast<HLSLOutArgExpr>(E->getArg(2));
+
+  CallArgList Args;
+  LValue Op1TmpLValue =
+      CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+  LValue Op2TmpLValue =
+      CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+  if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+    Args.reverseWritebacks();
+
+  Value *LowBits = nullptr;
+  Value *HighBits = nullptr;
+
+  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+    // The intrinsic returns {i32, i32}, or {<N x i32>, <N x i32>} for vectors.
+    llvm::Type *RetElementTy = CGF->Int32Ty;
+    if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
+      RetElementTy = llvm::VectorType::get(
+          CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+    auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+    CallInst *CI = CGF->Builder.CreateIntrinsic(
+        RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+    LowBits = CGF->Builder.CreateExtractValue(CI, 0);
+    HighBits = CGF->Builder.CreateExtractValue(CI, 1);
+  } else {
+    // Non-DXIL targets: expand splitdouble into bitcasts and shuffles here.
+    // TODO: This works around known SPIR-V limitations with splitdouble and
+    // should be handled in a later compilation stage instead. The expansion
+    // never builds an integer vector wider than 4 elements, presumably
+    // because Vulkan SPIR-V only supports 2- to 4-component vectors.
+
+    // Casts `<2 x double>` to `<4 x i32>`, then shuffles into low and high
+    // `<2 x i32>` vectors.
+    auto EmitDouble2Cast =
+        [](CodeGenFunction &CGF,
+           Value *DoubleVec2) -> std::pair<Value *, Value *> {
+      Value *BC = CGF.Builder.CreateBitCast(
+          DoubleVec2, FixedVectorType::get(CGF.Int32Ty, 4));
+      Value *LB = CGF.Builder.CreateShuffleVector(BC, {0, 2});
+      Value *HB = CGF.Builder.CreateShuffleVector(BC, {1, 3});
+      return std::make_pair(LB, HB);
+    };
+
+    if (!Op0->getType()->isVectorTy()) {
+      FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+      Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+      LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
+      HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
+    } else {
+      const auto *TargTy = E->getArg(0)->getType()->getAs<clang::VectorType>();
+      int NumElements = TargTy->getNumElements();
+      FixedVectorType *UintVec2 = FixedVectorType::get(CGF->Int32Ty, 2);
+
+      switch (NumElements) {
+      case 1: {
+        auto *Bitcast = CGF->Builder.CreateBitCast(Op0, UintVec2);
+        LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
+        HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
+        break;
+      }
+      case 2: {
+        auto [LB, HB] = EmitDouble2Cast(*CGF, Op0);
+        LowBits = LB;
+        HighBits = HB;
+        break;
+      }
+      case 3: {
+        // Split elements {0, 1} as a pair, then append element 2's words.
+        auto *Shuff = CGF->Builder.CreateShuffleVector(Op0, {0, 1});
+        auto [LB, HB] = EmitDouble2Cast(*CGF, Shuff);
+
+        auto *EV = CGF->Builder.CreateExtractElement(Op0, 2);
+        auto *ScalarBitcast = CGF->Builder.CreateBitCast(EV, UintVec2);
+
+        LowBits =
+            CGF->Builder.CreateShuffleVector(LB, ScalarBitcast, {0, 1, 2});
+        HighBits =
+            CGF->Builder.CreateShuffleVector(HB, ScalarBitcast, {0, 1, 3});
+        break;
+      }
+      case 4: {
+        // Split elements {0, 1} and {2, 3} as pairs, then concatenate halves.
+        auto *Shuff1 = CGF->Builder.CreateShuffleVector(Op0, {0, 1});
+        auto [LB1, HB1] = EmitDouble2Cast(*CGF, Shuff1);
+
+        auto *Shuff2 = CGF->Builder.CreateShuffleVector(Op0, {2, 3});
+        auto [LB2, HB2] = EmitDouble2Cast(*CGF, Shuff2);
+
+        LowBits = CGF->Builder.CreateShuffleVector(LB1, LB2, {0, 1, 2, 3});
+        HighBits = CGF->Builder.CreateShuffleVector(HB1, HB2, {0, 1, 2, 3});
+        break;
+      }
+      default: {
+        CGF->CGM.Error(E->getExprLoc(),
+                       "splitdouble doesn't support vectors larger than 4.");
+        return nullptr;
+      }
+      }
+    }
+  }
+  CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
+  auto *LastInst =
+      CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
+  CGF->EmitWritebacks(Args);
+  return LastInst;
+}
+
 /// getBuiltinLibFunction - Given a builtin id for a function like
 /// "__builtin_fabsf", return a Function* for "fabsf".
 llvm::Constant *CodeGenModule::getBuiltinLibFunction(const FunctionDecl *FD,
@@ -18962,92 +19082,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
-  // This should only be called when targeting DXIL
-  case Builtin::BI__builtin_hlsl_splitdouble: {
+  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
 
     assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
             E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
             E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
            "asuint operands types mismatch");
-    Value *Op0 = EmitScalarExpr(E->getArg(0));
-    const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
-    const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
-
-    CallArgList Args;
-    auto [Op1BaseLValue, Op1TmpLValue] =
-        EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
-    auto [Op2BaseLValue, Op2TmpLValue] =
-        EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
-
-    if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) {
-
-      llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
-
-      if (Op0->getType()->isVectorTy()) {
-        auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
-
-        llvm::VectorType *i32VecTy = llvm::VectorType::get(
-            Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
-        retType = llvm::StructType::get(i32VecTy, i32VecTy);
-      }
-
-      CallInst *CI =
-          Builder.CreateIntrinsic(retType, Intrinsic::dx_splitdouble, {Op0},
-                                  nullptr, "hlsl.splitdouble");
-
-      Value *arg0 = Builder.CreateExtractValue(CI, 0);
-      Value *arg1 = Builder.CreateExtractValue(CI, 1);
-
-      Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
-      auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
-      EmitWritebacks(*this, Args);
-      return s;
-    } 
-
-
-    if(!Op0->getType()->isVectorTy()){
-        FixedVectorType *destTy = FixedVectorType::get(Int32Ty, 2);
-        Value *bitcast = Builder.CreateBitCast(Op0, destTy);
-
-        Value *arg0 = Builder.CreateExtractElement(bitcast, 0.0);
-        Value *arg1 = Builder.CreateExtractElement(bitcast, 1.0);
-
-        Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
-        auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
-        EmitWritebacks(*this, Args);
-        return s;
-    }
-    
-    auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
-
-    int numElements = Op0VecTy -> getNumElements() * 2;
-
-    FixedVectorType *destTy = FixedVectorType::get(Int32Ty, numElements);
-      
-    Value *bitcast = Builder.CreateBitCast(Op0, destTy);
-
-    SmallVector<int> lowbitsIndex;
-    SmallVector<int> highbitsIndex;
-
-    for(int idx = 0; idx < numElements; idx += 2){
-      lowbitsIndex.push_back(idx);
-    }
-
-    for(int idx = 1; idx < numElements; idx += 2){
-      highbitsIndex.push_back(idx);
-    }
-
-    Value *arg0 = Builder.CreateShuffleVector(bitcast, lowbitsIndex);
-    Value *arg1 = Builder.CreateShuffleVector(bitcast, highbitsIndex);
-
-    Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
-    auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
-    EmitWritebacks(*this, Args);
-    return s;
-
+    return handleHlslSplitdouble(E, this);
   }
   }
   return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index ebdb04331ea6fb..44c6ec3737adc9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/Path.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <optional>
 using namespace clang;
@@ -4207,12 +4208,6 @@ static void emitWriteback(CodeGenFunction &CGF,
     CGF.EmitBlock(contBB);
 }
 
-static void emitWritebacks(CodeGenFunction &CGF,
-                           const CallArgList &args) {
-  for (const auto &I : args.writebacks())
-    emitWriteback(CGF, I);
-}
-
 static void deactivateArgCleanupsBeforeCall(CodeGenFunction &CGF,
                                             const CallArgList &CallArgs) {
   ArrayRef<CallArgList::CallArgCleanup> Cleanups =
@@ -4681,9 +4676,9 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
   IsUsed = true;
 }
 
-void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF,
-                                     const CallArgList &args) {
-  emitWritebacks(CGF, args);
+void CodeGenFunction::EmitWritebacks(const CallArgList &args) {
+  for (const auto &I : args.writebacks())
+    emitWriteback(*this, I);
 }
 
 void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
@@ -5902,7 +5897,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   // Emit any call-associated writebacks immediately.  Arguably this
   // should happen after any return-value munging.
   if (CallArgs.hasWritebacks())
-    emitWritebacks(*this, CallArgs);
+    CodeGenFunction::EmitWritebacks(CallArgs);
 
   // The stack cleanup for inalloca arguments has to run out of the normal
   // lexical order, so deactivate it and run it manually here.
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 35e6aa8cec6e7e..e90e8da3e9f1ea 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5478,9 +5478,8 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
   return std::make_pair(BaseLV, TempLV);
 }
 
-std::pair<LValue, LValue>
-CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
-                                    QualType Ty) {
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+                                           CallArgList &Args, QualType Ty) {
 
   auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
 
@@ -5495,7 +5494,7 @@ CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
   Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
                     LifetimeSize);
   Args.add(RValue::get(TmpAddr, *this), Ty);
-  return std::make_pair(BaseLV, TempLV);
+  return TempLV;
 }
 
 LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index c1d6b4e2e3460d..3ff4458fb32024 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4299,8 +4299,8 @@ class CodeGenFunction : public CodeGenTypeCache {
 
   std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
                                                   QualType Ty);
-  std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
-                                               CallArgList &Args, QualType Ty);
+  LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+                            QualType Ty);
 
   Address EmitExtVectorElementLValue(LValue V);
 
@@ -5151,7 +5151,7 @@ class CodeGenFunction : public CodeGenTypeCache {
                            unsigned ParmNum);
 
   /// EmitWriteback - Emit callbacks for function.
-  void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args);
+  void EmitWritebacks(const CallArgList &Args);
 
   /// EmitCallArg - Emit a single call argument.
   void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index a669a673f65a69..8ade4b27f360fb 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -446,17 +446,16 @@ template <typename T> constexpr uint asuint(T F) {
 /// \brief Split and interprets the lowbits and highbits of double D into uints.
 /// \param D The input double.
 /// \param lowbits The output lowbits of D.
-/// \param highbits The highbits lowbits D.
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+/// \param highbits The output highbits of D.
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
 void asuint(double, out uint, out uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
 void asuint(double2, out uint2, out uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
 void asuint(double3, out uint3, out uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
 void asuint(double4, out uint4, out uint4);
 
-
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 993da91817f774..4346863d1869d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,7 +1698,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
-bool CheckArgTypeIsIncorrect(
+bool CheckArgTypeIsCorrect(
     Sema *S, Expr *Arg, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   QualType PassedType = Arg->getType();
@@ -1713,12 +1713,12 @@ bool CheckArgTypeIsIncorrect(
   return false;
 }
 
-bool CheckArgsTypesAreCorrect(
+bool CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
     Expr *Arg = TheCall->getArg(i);
-    if (CheckArgTypeIsIncorrect(S, Arg, ExpectedType, Check)) {
+    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
       return true;
     }
   }
@@ -1729,8 +1729,8 @@ static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
   auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
     return !PassedType->hasFloatingRepresentation();
   };
-  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                  checkAllFloatTypes);
+  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                    checkAllFloatTypes);
 }
 
 static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
@@ -1741,8 +1741,8 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
             : PassedType;
     return !BaseType->isHalfType() && !BaseType->isFloat32Type();
   };
-  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                  checkFloatorHalf);
+  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                    checkFloatorHalf);
 }
 
 static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
@@ -1751,24 +1751,24 @@ static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
       return VecTy->getElementType()->isDoubleType();
     return false;
   };
-  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
-                                  checkDoubleVector);
+  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+                                    checkDoubleVector);
 }
 static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
   auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
     return !PassedType->hasIntegerRepresentation() &&
            !PassedType->hasFloatingRepresentation();
   };
-  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.IntTy,
-                                  checkAllSignedTypes);
+  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
+                                    checkAllSignedTypes);
 }
 
 static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
   auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
     return !PassedType->hasUnsignedIntegerRepresentation();
   };
-  return CheckArgsTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
-                                  checkAllUnsignedTypes);
+  return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
+                                    checkAllUnsignedTypes);
 }
 
 static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2083,34 +2083,16 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_splitdouble: {
+  case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    Expr *Op0 = TheCall->getArg(0);
-
-    auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
-      return !PassedType->hasFloatingRepresentation();
-    };
-
-    if (CheckArgTypeIsIncorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
-                                CheckIsNotDouble)) {
-      return true;
-    }
-
-    Expr *Op1 = TheCall->getArg(1);
-    Expr *Op2 = TheCall->getArg(2);
-
-    auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
-      return !PassedType->hasUnsignedIntegerRepresentation();
-    };
-
-    if (CheckArgTypeIsIncorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
-                                CheckIsNotUint) ||
-        CheckArgTypeIsIncorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
-                                CheckIsNotUint)) {
+    if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
+        CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            1) ||
+        CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+                            2))
       return true;
-    }
 
     break;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index e1f42824cfe5ed..4fad8a122b7422 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,5 +1,5 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPIRV
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s --check-prefix=SPIRV
 
 
 
@@ -9,7 +9,8 @@
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
 // SPIRV: define spir_func {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[CAST:%.*]] = bitcast double [[VALD]] to <2 x i32>
+// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[REG]] to <2 x i32>
 // SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
 // SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
 uint test_scalar(double D) {
@@ -18,18 +19,81 @@ uint test_scalar(double D) {
   return A + B;
 }
 
+// CHECK: define {{.*}} <1 x i32> {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[TRUNC:%.*]] = extractelement <1 x double> %D, i64 0
+// CHECK-NEXT: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[TRUNC]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <1 x i32> {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[REG:%.*]] = load <1 x double>, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[TRUNC:%.*]] = extractelement <1 x double> %1, i64 0
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[TRUNC]] to <2 x i32>
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
+uint1 test_double1(double1 D) {
+  uint A, B;
+  asuint(D, A, B);
+  return A + B;
+}
+
+// CHECK: define {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[REG:%.*]] = load <2 x double>, ptr [[VALD]].addr, align 16
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[REG]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+uint2 test_vector2(double2 D) {
+  uint2 A, B;
+  asuint(D, A, B);
+  return A + B;
+}
 
-// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
 // CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
 // CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
-// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[CAST:%.*]] = bitcast <3 x double> [[VALD]] to <6 x i32>
-// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
-// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
-uint3 test_vector(double3 D) {
+// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[VALRET1:%.*]] = shufflevector <3 x double> [[REG]], <3 x double> poison, <2 x i32> <i32 0, i32 1>
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[VALRET1]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+// SPIRV-NEXT: [[EXTRACT:%.*]] = extractelement <3 x double> [[REG]], i64 2
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[EXTRACT]] to <2 x i32>
+// SPIRV-NEXT: %[[#]] = shufflevector <2 x i32> [[SHUF1]], <2 x i32> [[CAST]], <3 x i32> <i32 0, i32 1, i32 2>
+// SPIRV-NEXT: %[[#]] = shufflevector <2 x i32> [[SHUF2]], <2 x i32> [[CAST]], <3 x i32> <i32 0, i32 1, i32 3>
+uint3 test_vector3(double3 D) {
   uint3 A, B;
   asuint(D, A, B);
   return A + B;
 }
+
+// CHECK: define {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <4 x i32>, <4 x i32> } @llvm.dx.splitdouble.v4i32(<4 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[REG:%.*]] = load <4 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[VALRET1:%.*]] = shufflevector <4 x double> [[REG]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[VALRET1]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+// SPIRV-NEXT: [[VALRET2:%.*]] = shufflevector <4 x double> [[REG]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+// SPIRV-NEXT: [[CAST2:%.*]] = bitcast <2 x double> [[VALRET2]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF3:%.*]] = shufflevector <4 x i32> [[CAST2]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF4:%.*]] = shufflevector <4 x i32> [[CAST2]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+// SPIRV-NEXT: shufflevector <2 x i32> [[SHUF1]], <2 x i32> [[SHUF3]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+// SPIRV-NEXT: shufflevector <2 x i32> [[SHUF2]], <2 x i32> [[SHUF4]], <4 x i32> <i32 0, i32 1, i32 3, i32 3>
+
+uint4 test_vector4(double4 D) {
+  uint4 A, B;
+  asuint(D, A, B);
+  return A + B;
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 996109ed352c7e..e30d37f69f781e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,9 +92,4 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
 def int_dx_splitdouble : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem]>;
 def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-
-def int_dx_splitdouble : DefaultAttrsIntrinsic<
-    [llvm_anyint_ty, LLVMMatchType<0>], 
-    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
-    [IntrNoMem, IntrWillReturn]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 338cc546348b8d..68ae5de06423c2 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,7 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
 def HandleTy : DXILOpParamType;
 def ResBindTy : DXILOpParamType;
 def ResPropsTy : DXILOpParamType;
-def ResSplitDoubleTy : DXILOpParamType;
+def SplitDoubleTy : DXILOpParamType;
 
 class DXILOpClass;
 
@@ -783,7 +783,7 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
 def SplitDouble :  DXILOp<102, splitDouble> {
   let Doc = "Splits a double into 2 uints";
   let arguments = [OverloadTy];
-  let result = ResSplitDoubleTy;
+  let result = SplitDoubleTy;
   let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 982d7849d9bb8b..5d5bb3eacace25 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,7 +229,7 @@ static StructType *getResPropsType(LLVMContext &Context) {
   return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
 }
 
-static StructType *getResSplitDoubleType(LLVMContext &Context) {
+static StructType *getSplitDoubleType(LLVMContext &Context) {
   if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
     return ST;
   Type *Int32Ty = Type::getInt32Ty(Context);
@@ -273,8 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
     return getResBindType(Ctx);
   case OpParamType::ResPropsTy:
     return getResPropsType(Ctx);
-  case OpParamType::ResSplitDoubleTy:
-    return getResSplitDoubleType(Ctx);
+  case OpParamType::SplitDoubleTy:
+    return getSplitDoubleType(Ctx);
   }
   llvm_unreachable("Invalid parameter kind");
   return nullptr;
@@ -476,8 +476,8 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
   return ::getResRetType(ElementTy);
 }
 
-StructType *DXILOpBuilder::getResSplitDoubleType(LLVMContext &Context) {
-  return ::getResSplitDoubleType(Context);
+StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) {
+  return ::getSplitDoubleType(Context);
 }
 
 StructType *DXILOpBuilder::getHandleType() {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 8b1e87c283146c..df5a0240870f4a 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -51,7 +51,7 @@ class DXILOpBuilder {
   StructType *getResRetType(Type *ElementTy);
 
   /// Get the `%dx.types.splitdouble` type.
-  StructType *getResSplitDoubleType(LLVMContext &Context);
+  StructType *getSplitDoubleType(LLVMContext &Context);
 
   /// Get the `%dx.types.Handle` type.
   StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 344f7bb517c2bc..f7722d77074764 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -23,9 +23,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/InitializePasses.h"
-#include "llvm/Object/Error.h"
 #include "llvm/Pass.h"
-#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "dxil-op-lower"
@@ -131,6 +129,30 @@ class OpLowerer {
     });
   }
 
+  [[nodiscard]] bool replaceFunctionWithNamedStructOp(
+      Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
+      llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
+    bool IsVectorArgExpansion = isVectorArgExpansion(F);
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      SmallVector<Value *> Args;
+      OpBuilder.getIRB().SetInsertPoint(CI);
+      if (IsVectorArgExpansion) {
+        SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
+        Args.append(NewArgs.begin(), NewArgs.end());
+      } else
+        Args.append(CI->arg_begin(), CI->arg_end());
+
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
+      if (Error E = OpCall.takeError())
+        return E;
+      if (Error E = ReplaceUses(CI, *OpCall))
+        return E;
+
+      return Error::success();
+    });
+  }
+
   /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
   /// is intended to be removed by the end of lowering. This is used to allow
   /// lowering of ops which need to change their return or argument types in a
@@ -267,20 +289,17 @@ class OpLowerer {
   }
 
   Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
-    IRBuilder<> &IRB = OpBuilder.getIRB();
-
     for (Use &U : make_early_inc_range(Intrin->uses())) {
       if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
 
         if (EVI->getNumIndices() != 1)
           return createStringError(std::errc::invalid_argument,
                                    "Splitdouble has only 2 elements");
-
-        size_t IndexVal = EVI->getIndices()[0];
-
-        auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
-        EVI->replaceAllUsesWith(OpEVI);
-        EVI->eraseFromParent();
+        EVI->setOperand(0, Op);
+      } else {
+        return make_error<StringError>(
+            "Splitdouble use is not ExtractValueInst",
+            inconvertibleErrorCode());
       }
     }
 
@@ -486,33 +505,6 @@ class OpLowerer {
     });
   }
 
-  [[nodiscard]] bool lowerSplitDouble(Function &F) {
-    IRBuilder<> &IRB = OpBuilder.getIRB();
-    return replaceFunction(F, [&](CallInst *CI) -> Error {
-      IRB.SetInsertPoint(CI);
-
-      Value *Arg0 = CI->getArgOperand(0);
-
-      if (Arg0->getType()->isVectorTy()) {
-        return make_error<StringError>(
-            "splitdouble doesn't support lowering vector types.",
-            inconvertibleErrorCode());
-      }
-
-      Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
-
-      std::array<Value *, 1> Args{Arg0};
-      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
-          OpCode::SplitDouble, Args, CI->getName(), NewRetTy);
-      if (Error E = OpCall.takeError())
-        return E;
-      if (Error E = replaceSplitDoubleCallUsages(CI, *OpCall))
-        return E;
-
-      return Error::success();
-    });
-  }
-
   bool lowerIntrinsics() {
     bool Updated = false;
     bool HasErrors = false;
@@ -541,8 +533,15 @@ class OpLowerer {
       case Intrinsic::dx_typedBufferStore:
         HasErrors |= lowerTypedBufferStore(F);
         break;
+      // TODO: this can be removed when
+      // https://github.com/llvm/llvm-project/issues/113192 is fixed
       case Intrinsic::dx_splitdouble:
-        HasErrors |= lowerSplitDouble(F);
+        HasErrors |= replaceFunctionWithNamedStructOp(
+            F, OpCode::SplitDouble,
+            OpBuilder.getSplitDoubleType(M.getContext()),
+            [&](CallInst *CI, CallInst *Op) {
+              return replaceSplitDoubleCallUsages(CI, Op);
+            });
         break;
       }
       Updated = true;
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
deleted file mode 100644
index 759590fa56279b..00000000000000
--- a/llvm/test/CodeGen/DirectX/split-double.ll
+++ /dev/null
@@ -1,45 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-
-define void @test_vector_double_split_void(<2 x double> noundef %d) {
-; CHECK-LABEL: define void @test_vector_double_split_void(
-; CHECK-SAME: <2 x double> noundef [[D:%.*]]) {
-; CHECK-NEXT:    [[D_I0:%.*]] = extractelement <2 x double> [[D]], i64 0
-; CHECK-NEXT:    [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
-; CHECK-NEXT:    [[D_I1:%.*]] = extractelement <2 x double> [[D]], i64 1
-; CHECK-NEXT:    [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
-; CHECK-NEXT:    ret void
-;
-  %hlsl.asuint = call { <2 x i32>, <2 x i32> }  @llvm.dx.splitdouble.v2i32(<2 x double> %d)
-  ret void
-}
-
-define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
-; CHECK-LABEL: define noundef <3 x i32> @test_vector_double_split(
-; CHECK-SAME: <3 x double> noundef [[D:%.*]]) {
-; CHECK-NEXT:    [[D_I0:%.*]] = extractelement <3 x double> [[D]], i64 0
-; CHECK-NEXT:    [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
-; CHECK-NEXT:    [[D_I1:%.*]] = extractelement <3 x double> [[D]], i64 1
-; CHECK-NEXT:    [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
-; CHECK-NEXT:    [[D_I2:%.*]] = extractelement <3 x double> [[D]], i64 2
-; CHECK-NEXT:    [[HLSL_ASUINT_I2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I2]])
-; CHECK-NEXT:    [[DOTELEM0:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
-; CHECK-NEXT:    [[DOTELEM01:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 0
-; CHECK-NEXT:    [[DOTELEM02:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 0
-; CHECK-NEXT:    [[DOTELEM1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
-; CHECK-NEXT:    [[DOTELEM13:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 1
-; CHECK-NEXT:    [[DOTELEM14:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 1
-; CHECK-NEXT:    [[DOTI0:%.*]] = add i32 [[DOTELEM0]], [[DOTELEM1]]
-; CHECK-NEXT:    [[DOTI1:%.*]] = add i32 [[DOTELEM01]], [[DOTELEM13]]
-; CHECK-NEXT:    [[DOTI2:%.*]] = add i32 [[DOTELEM02]], [[DOTELEM14]]
-; CHECK-NEXT:    [[DOTUPTO015:%.*]] = insertelement <3 x i32> poison, i32 [[DOTI0]], i64 0
-; CHECK-NEXT:    [[DOTUPTO116:%.*]] = insertelement <3 x i32> [[DOTUPTO015]], i32 [[DOTI1]], i64 1
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <3 x i32> [[DOTUPTO116]], i32 [[DOTI2]], i64 2
-; CHECK-NEXT:    ret <3 x i32> [[TMP1]]
-;
-  %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
-  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
-  %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
-  %3 = add <3 x i32> %1, %2
-  ret <3 x i32> %3
-}
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index 6da3b5797b4cba..52caad71e210b7 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,17 +1,77 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
-; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,NOLOWER
+; RUN: opt -passes='function(scalarizer),module(dxil-op-lower)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,WITHLOWER
 
-; Make sure DXILOpLowering is correctly generating the dxil op, with and without scalarizer.
-
-; CHECK-LABEL: define noundef i32 @test_scalar_double_split
-define noundef i32 @test_scalar_double_split(double noundef %D) local_unnamed_addr {
-entry:
-  ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
-  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
-  ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+define i32 @test_scalar(double noundef %D) {
+; CHECK-LABEL: define i32 @test_scalar(
+; CHECK-SAME: double noundef [[D:%.*]]) {
+; NOLOWER-NEXT:    [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D]])
+; WITHLOWER-NEXT:  [[HLSL_ASUINT_I0:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D]])
+; NOLOWER-NEXT:    [[EV1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
+; NOLOWER-NEXT:    [[EV2:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
+; WITHLOWER-NEXT:  [[EV1:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 0
+; WITHLOWER-NEXT:  [[EV2:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 1
+; CHECK-NEXT:      [[ADD:%.*]] = add i32 [[EV1]], [[EV2]]
+; CHECK-NEXT:      ret i32 [[ADD]]
+;
   %hlsl.splitdouble = call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
-  %0 = extractvalue { i32, i32 } %hlsl.splitdouble, 0
-  %1 = extractvalue { i32, i32 } %hlsl.splitdouble, 1
-  %add = add i32 %0, %1
+  %1 = extractvalue { i32, i32 } %hlsl.splitdouble, 0
+  %2 = extractvalue { i32, i32 } %hlsl.splitdouble, 1
+  %add = add i32 %1, %2
   ret i32 %add
 }
+
+
+define void @test_vector_double_split_void(<2 x double> noundef %d) {
+; CHECK-LABEL: define void @test_vector_double_split_void(
+; CHECK-SAME: <2 x double> noundef [[D:%.*]]) {
+; CHECK-NEXT:      [[D_I0:%.*]] = extractelement <2 x double> [[D]], i64 0
+; NOLOWER-NEXT:    [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
+; WITHLOWER-NEXT:  [[HLSL_ASUINT_I0:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I0]])
+; CHECK-NEXT:      [[D_I1:%.*]] = extractelement <2 x double> [[D]], i64 1
+; NOLOWER-NEXT:    [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
+; WITHLOWER-NEXT:  [[HLSL_ASUINT_I1:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I1]])
+; CHECK-NEXT:      ret void
+;
+  %hlsl.asuint = call { <2 x i32>, <2 x i32> }  @llvm.dx.splitdouble.v2i32(<2 x double> %d)
+  ret void
+}
+
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
+; CHECK-LABEL: define noundef <3 x i32> @test_vector_double_split(
+; CHECK-SAME: <3 x double> noundef [[D:%.*]]) {
+; CHECK-NEXT:      [[D_I0:%.*]] = extractelement <3 x double> [[D]], i64 0
+; NOLOWER-NEXT:    [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
+; WITHLOWER-NEXT:  [[HLSL_ASUINT_I0:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I0]])
+; CHECK-NEXT:      [[D_I1:%.*]] = extractelement <3 x double> [[D]], i64 1
+; NOLOWER-NEXT:    [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
+; WITHLOWER-NEXT:  [[HLSL_ASUINT_I1:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I1]])
+; CHECK-NEXT:      [[D_I2:%.*]] = extractelement <3 x double> [[D]], i64 2
+; NOLOWER-NEXT:    [[HLSL_ASUINT_I2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I2]])
+; WITHLOWER-NEXT:  [[HLSL_ASUINT_I2:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I2]])
+; NOLOWER-NEXT:    [[DOTELEM0:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
+; WITHLOWER-NEXT:  [[DOTELEM0:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 0
+; NOLOWER-NEXT:    [[DOTELEM01:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 0
+; WITHLOWER-NEXT:  [[DOTELEM01:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I1]], 0
+; NOLOWER-NEXT:    [[DOTELEM02:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 0
+; WITHLOWER-NEXT:  [[DOTELEM02:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I2]], 0
+; NOLOWER-NEXT:    [[DOTELEM1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
+; WITHLOWER-NEXT:  [[DOTELEM1:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 1
+; NOLOWER-NEXT:    [[DOTELEM13:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 1
+; WITHLOWER-NEXT:  [[DOTELEM13:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I1]], 1
+; NOLOWER-NEXT:    [[DOTELEM14:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 1
+; WITHLOWER-NEXT:  [[DOTELEM14:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I2]], 1
+; CHECK-NEXT:      [[DOTI0:%.*]] = add i32 [[DOTELEM0]], [[DOTELEM1]]
+; CHECK-NEXT:      [[DOTI1:%.*]] = add i32 [[DOTELEM01]], [[DOTELEM13]]
+; CHECK-NEXT:      [[DOTI2:%.*]] = add i32 [[DOTELEM02]], [[DOTELEM14]]
+; CHECK-NEXT:      [[DOTUPTO015:%.*]] = insertelement <3 x i32> poison, i32 [[DOTI0]], i64 0
+; CHECK-NEXT:      [[DOTUPTO116:%.*]] = insertelement <3 x i32> [[DOTUPTO015]], i32 [[DOTI1]], i64 1
+; CHECK-NEXT:      [[TMP1:%.*]] = insertelement <3 x i32> [[DOTUPTO116]], i32 [[DOTI2]], i64 2
+; CHECK-NEXT:      ret <3 x i32> [[TMP1]]
+;
+  %hlsl.asuint = call { <3 x i32>, <3 x i32> }  @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+  %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+  %3 = add <3 x i32> %1, %2
+  ret <3 x i32> %3
+}
diff --git a/llvm/test/CodeGen/DirectX/splitdouble_error.ll b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
deleted file mode 100644
index acfd52b24c9cc3..00000000000000
--- a/llvm/test/CodeGen/DirectX/splitdouble_error.ll
+++ /dev/null
@@ -1,16 +0,0 @@
-; RUN: not opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
-
-; DXIL operation splitdouble doesn't support vector types.
-; CHECK: in function test_vector_double_split
-; CHECK-SAME: splitdouble doesn't support lowering vector types.
-
-define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
-entry:
-  %hlsl.splitdouble = tail call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %D)
-  %0 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 0
-  %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 1
-  %add = add <3 x i32> %0, %1
-  ret <3 x i32> %add
-}
-
-declare { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
index c057042c0d142e..2243e6b9944912 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
@@ -3,14 +3,52 @@
 
 ; Make sure lowering is correctly generating spirv code.
 
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scalar_function:]] = OpTypeFunction %[[#int_32]] %[[#double]]
+; CHECK-DAG: %[[#vec_2_int_32:]] = OpTypeVector %[[#int_32]] 2
+; CHECK-DAG: %[[#vec_4_int_32:]] = OpTypeVector %[[#int_32]] 4
+; CHECK-DAG: %[[#vec_3_int_32:]] = OpTypeVector %[[#int_32]] 3
+; CHECK-DAG: %[[#vec_3_double:]] = OpTypeVector %[[#double]] 3
+; CHECK-DAG: %[[#vector_function:]] = OpTypeFunction %[[#vec_3_int_32]] %[[#vec_3_double]]
+; CHECK-DAG: %[[#vec_2_double:]] = OpTypeVector %[[#double]] 2
+
+
 define spir_func noundef i32 @test_scalar(double noundef %D) local_unnamed_addr {
 entry:
-  ; CHECK: %[[#]] = OpBitcast %[[#]] %[[#]]
+  ; CHECK-LABEL: ; -- Begin function test_scalar
+  ; CHECK: %[[#param:]] = OpFunctionParameter %[[#double]]
+  ; CHECK: %[[#bitcast:]] = OpBitcast %[[#vec_2_int_32]] %[[#param]]
   %0 = bitcast double %D to <2 x i32>
-  ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 0
+  ; CHECK: %[[#]] = OpCompositeExtract %[[#int_32]] %[[#bitcast]] 0
   %1 = extractelement <2 x i32> %0, i64 0
-  ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 1
+  ; CHECK: %[[#]] = OpCompositeExtract %[[#int_32]] %[[#bitcast]] 1
   %2 = extractelement <2 x i32> %0, i64 1
   %add = add i32 %1, %2
   ret i32 %add
 }
+
+
+define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
+entry:
+  ; CHECK-LABEL: ; -- Begin function test_vector
+  ; CHECK: %[[#param:]] = OpFunctionParameter %[[#vec_3_double]]
+  ; %[[#SHUFF1:]] = OpVectorShuffle %[[#vec_2_double]] %[[#param]] %[[#]] 0 1
+  ; %[[#CAST1:]] = OpBitcast %[[#vec_4_int_32]] %[[#SHUFF1]]
+  ; %[[#SHUFF2:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 0 2
+  ; %[[#SHUFF3:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 1 3
+  ; %[[#EXTRACT:]] = OpCompositeExtract %[[#double]] %[[#param]] 2
+  ; %[[#CAST2:]] = OpBitcast %[[#vec_2_int_32]] %[[#EXTRACT]]
+  ; %[[#]] = OpVectorShuffle %7 %[[#SHUFF2]] %[[#CAST2]] 0 1 2
+  ; %[[#]] = OpVectorShuffle %7 %[[#SHUFF3]] %[[#CAST2]] 0 1 3
+  %0 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 0, i32 1>
+  %1 = bitcast <2 x double> %0 to <4 x i32>
+  %2 = shufflevector <4 x i32> %1, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+  %3 = shufflevector <4 x i32> %1, <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+  %4 = extractelement <3 x double> %D, i64 2
+  %5 = bitcast double %4 to <2 x i32>
+  %6 = shufflevector <2 x i32> %2, <2 x i32> %5, <3 x i32> <i32 0, i32 1, i32 2>
+  %7 = shufflevector <2 x i32> %3, <2 x i32> %5, <3 x i32> <i32 0, i32 1, i32 3>
+  %add = add <3 x i32> %6, %7
+  ret <3 x i32> %add
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
deleted file mode 100644
index 58bd5a046ff3d1..00000000000000
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
+++ /dev/null
@@ -1,14 +0,0 @@
-; RUN: opt -S -scalarizer -mtriple=spirv-vulkan-library %s 2>&1 | llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -o - | FileCheck %s
-
-; SPIRV lowering for splitdouble should relly on the scalarizer.
-
-define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
-entry:
-  ; CHECK-COUNT-3: %[[#]] = OpBitcast %[[#]] %[[#]]
-  ; CHECK-COUNT-3: %[[#]] = OpCompositeExtract %[[#]] %[[#]] [[0-2]]
-  %0 = bitcast <3 x double> %D to <6 x i32>
-  %1 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
-  %2 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
-  %add = add <3 x i32> %1, %2
-  ret <3 x i32> %add
-}

>From a676944b08172a0327838309e5bde17b4cf96a79 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 24 Oct 2024 17:33:23 +0000
Subject: [PATCH 17/18] Adding Sema test

---
 clang/lib/CodeGen/CGBuiltin.cpp               |  5 +-
 clang/lib/Sema/SemaHLSL.cpp                   | 12 +++++
 .../test/SemaHLSL/BuiltIns/asuint-errors.hlsl | 26 ++++++++++
 .../SemaHLSL/BuiltIns/splitdouble-errors.hlsl | 52 +++++++++++++++++++
 4 files changed, 93 insertions(+), 2 deletions(-)
 create mode 100644 clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 05a41d75c1c419..57ea4ef8b3f781 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -134,8 +134,9 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
     // For Non DXIL targets we generate the instructions.
     // TODO: This code accounts for known limitations in
     // SPIR-V and splitdouble. Such should be handled,
-    // in a later compilation stage. After [issue link here]
-    // is fixed, this shall be refactored.
+    // in a later compilation stage. After
+    // https://github.com/llvm/llvm-project/issues/113597 is fixed, this shall
+    // be refactored.
 
     // casts `<2 x double>` to `<4 x i32>`, then shuffles into high and low
     // `<2 x i32>` vectors.
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4346863d1869d9..7c13294afa7db7 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1745,6 +1745,15 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
                                     checkFloatorHalf);
 }
 
+static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
+                                  unsigned ArgIndex) {
+  auto *Arg = TheCall->getArg(ArgIndex);
+  if (Arg->isModifiableLvalue(S->getASTContext()) == Expr::MLV_Valid)
+    return false;
+  S->Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue) << Arg << 0;
+  return true;
+}
+
 static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
   auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
     if (const auto *VecTy = PassedType->getAs<VectorType>())
@@ -2093,6 +2102,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
         CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
                             2))
       return true;
+    if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
+        CheckModifiableLValue(&SemaRef, TheCall, 2))
+      return true;
 
     break;
   }
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
index b9a920f9f1b4d0..4adb0555c35be6 100644
--- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -27,3 +27,29 @@ uint test_asuint_half(half p1) {
     // expected-note at hlsl/hlsl_detail.h:* {{candidate template ignored: could not match 'vector<half, N>' against 'half'}}
     // expected-note at hlsl/hlsl_detail.h:* {{candidate template ignored: substitution failure [with U = uint, T = half]: no type named 'Type'}}
 }
+
+void test_asuint_first_arg_const(double D) {
+  const uint A = 0;
+  uint B;
+  asuint(D, A, B);
+ // expected-error at hlsl/hlsl_intrinsics.h:* {{read-only variable is not assignable}} 
+}
+
+void test_asuint_second_arg_const(double D) {
+  const uint A = 0;
+  uint B;
+  asuint(D, B, A);
+ // expected-error at hlsl/hlsl_intrinsics.h:* {{read-only variable is not assignable}} 
+}
+
+void test_asuint_imidiate_value(double D) {
+  uint B;
+  asuint(D, B, 1);
+ // expected-error at -1 {{cannot bind non-lvalue argument 1 to out paramemter}} 
+}
+
+void test_asuint_expr(double D) {
+  uint B;
+  asuint(D, B, B + 1);
+ // expected-error at -1 {{cannot bind non-lvalue argument B + 1 to out paramemter}} 
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
new file mode 100644
index 00000000000000..5bac11c4216a98
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
@@ -0,0 +1,52 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -verify
+
+void test_no_second_arg(double D) {
+  __builtin_hlsl_elementwise_splitdouble(D);
+ // expected-error at -1 {{too few arguments to function call, expected 3, have 1}} 
+}
+
+void test_no_third_arg(double D) {
+  uint A;
+  __builtin_hlsl_elementwise_splitdouble(D, A);
+ // expected-error at -1 {{too few arguments to function call, expected 3, have 2}} 
+}
+
+void test_too_many_arg(double D) {
+  uint A, B, C;
+  __builtin_hlsl_elementwise_splitdouble(D, A, B, C);
+ // expected-error at -1 {{too many arguments to function call, expected 3, have 4}} 
+}
+
+void test_first_arg_type_mismatch(bool3 D) {
+  uint3 A, B;
+  __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error at -1 {{invalid operand of type 'bool3' (aka 'vector<bool, 3>') where 'double' or a vector of such type is required}} 
+}
+
+void test_second_arg_type_mismatch(double D) {
+  bool A;
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error at -1 {{invalid operand of type 'bool' where 'unsigned int' or a vector of such type is required}} 
+}
+
+void test_third_arg_type_mismatch(double D) {
+  bool A;
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, B, A);
+ // expected-error at -1 {{invalid operand of type 'bool' where 'unsigned int' or a vector of such type is required}} 
+}
+
+void test_const_second_arg(double D) {
+  const uint A = 1;
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error at -1 {{cannot bind non-lvalue argument}} 
+}
+
+void test_const_third_arg(double D) {
+  uint A;
+  const uint B = 2;
+  __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error at -1 {{cannot bind non-lvalue argument}} 
+}

>From 104479d3973b29ece6c0064d06d08bdfa9195cca Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Sat, 26 Oct 2024 01:52:34 +0000
Subject: [PATCH 18/18] Address review comments: simplify non-DXIL splitdouble codegen

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 88 ++++---------------
 clang/lib/CodeGen/CGCall.cpp                  |  2 +-
 clang/lib/Sema/SemaHLSL.cpp                   |  8 +-
 .../CodeGenHLSL/builtins/splitdouble.hlsl     | 64 ++++++--------
 .../SemaHLSL/BuiltIns/splitdouble-errors.hlsl | 30 ++++++-
 llvm/test/CodeGen/DirectX/splitdouble.ll      |  1 -
 .../SPIRV/hlsl-intrinsics/splitdouble.ll      | 36 +++-----
 7 files changed, 91 insertions(+), 138 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 57ea4ef8b3f781..86aecf494cbcda 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -132,23 +132,6 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
 
   } else {
     // For Non DXIL targets we generate the instructions.
-    // TODO: This code accounts for known limitations in
-    // SPIR-V and splitdouble. Such should be handled,
-    // in a later compilation stage. After
-    // https://github.com/llvm/llvm-project/issues/113597 is fixed, this shall
-    // be refactored.
-
-    // casts `<2 x double>` to `<4 x i32>`, then shuffles into high and low
-    // `<2 x i32>` vectors.
-    auto EmitDouble2Cast =
-        [](CodeGenFunction &CGF,
-           Value *DoubleVec2) -> std::pair<Value *, Value *> {
-      Value *BC = CGF.Builder.CreateBitCast(
-          DoubleVec2, FixedVectorType::get(CGF.Int32Ty, 4));
-      Value *LB = CGF.Builder.CreateShuffleVector(BC, {0, 2});
-      Value *HB = CGF.Builder.CreateShuffleVector(BC, {1, 3});
-      return std::make_pair(LB, HB);
-    };
 
     if (!Op0->getType()->isVectorTy()) {
       FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
@@ -157,58 +140,25 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
       LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
       HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
     } else {
-
-      const auto *TargTy = E->getArg(0)->getType()->getAs<clang::VectorType>();
-
-      int NumElements = TargTy->getNumElements();
-
-      FixedVectorType *UintVec2 = FixedVectorType::get(CGF->Int32Ty, 2);
-
-      switch (NumElements) {
-      case 1: {
-        auto *Bitcast = CGF->Builder.CreateBitCast(Op0, UintVec2);
-
-        LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
-        HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
-        break;
-      }
-      case 2: {
-        auto [LB, HB] = EmitDouble2Cast(*CGF, Op0);
-        LowBits = LB;
-        HighBits = HB;
-        break;
-      }
-
-      case 3: {
-        auto *Shuff = CGF->Builder.CreateShuffleVector(Op0, {0, 1});
-        auto [LB, HB] = EmitDouble2Cast(*CGF, Shuff);
-
-        auto *EV = CGF->Builder.CreateExtractElement(Op0, 2);
-        auto *ScalarBitcast = CGF->Builder.CreateBitCast(EV, UintVec2);
-
-        LowBits =
-            CGF->Builder.CreateShuffleVector(LB, ScalarBitcast, {0, 1, 2});
-        HighBits =
-            CGF->Builder.CreateShuffleVector(HB, ScalarBitcast, {0, 1, 3});
-        break;
-      }
-      case 4: {
-
-        auto *Shuff1 = CGF->Builder.CreateShuffleVector(Op0, {0, 1});
-        auto [LB1, HB1] = EmitDouble2Cast(*CGF, Shuff1);
-
-        auto *Shuff2 = CGF->Builder.CreateShuffleVector(Op0, {2, 3});
-        auto [LB2, HB2] = EmitDouble2Cast(*CGF, Shuff2);
-
-        LowBits = CGF->Builder.CreateShuffleVector(LB1, LB2, {0, 1, 2, 3});
-        HighBits = CGF->Builder.CreateShuffleVector(HB1, HB2, {0, 1, 3, 3});
-        break;
-      }
-      default: {
-        CGF->CGM.Error(E->getExprLoc(),
-                       "splitdouble doesn't support vectors larger than 4.");
-        return nullptr;
-      }
+      int NumElements = 1;
+      if (const auto *VecTy =
+              E->getArg(0)->getType()->getAs<clang::VectorType>())
+        NumElements = VecTy->getNumElements();
+
+      FixedVectorType *Uint32VecTy =
+          FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
+      Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
+      if (NumElements == 1) {
+        LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
+        HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
+      } else {
+        SmallVector<int> EvenMask, OddMask;
+        for (int I = 0, E = NumElements; I != E; ++I) {
+          EvenMask.push_back(I * 2);
+          OddMask.push_back(I * 2 + 1);
+        }
+        LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
+        HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
       }
     }
   }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 44c6ec3737adc9..4a3d82cf59e0f5 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -5897,7 +5897,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   // Emit any call-associated writebacks immediately.  Arguably this
   // should happen after any return-value munging.
   if (CallArgs.hasWritebacks())
-    CodeGenFunction::EmitWritebacks(CallArgs);
+    EmitWritebacks(CallArgs);
 
   // The stack cleanup for inalloca arguments has to run out of the normal
   // lexical order, so deactivate it and run it manually here.
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 7c13294afa7db7..a472538236e2d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1748,9 +1748,11 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                   unsigned ArgIndex) {
   auto *Arg = TheCall->getArg(ArgIndex);
-  if (Arg->isModifiableLvalue(S->getASTContext()) == Expr::MLV_Valid)
+  SourceLocation OrigLoc = Arg->getExprLoc();
+  if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
+      Expr::MLV_Valid)
     return false;
-  S->Diag(Arg->getBeginLoc(), diag::error_hlsl_inout_lvalue) << Arg << 0;
+  S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
   return true;
 }
 
@@ -2102,10 +2104,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
         CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
                             2))
       return true;
+
     if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
         CheckModifiableLValue(&SemaRef, TheCall, 2))
       return true;
-
     break;
   }
   case Builtin::BI__builtin_elementwise_acos:
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 4fad8a122b7422..a883c9d5cc3555 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -4,13 +4,14 @@
 
 
 // CHECK: define {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK:      [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+//
 // SPIRV: define spir_func {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
-// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[REG]] to <2 x i32>
+// SPIRV-NOT:  @llvm.dx.splitdouble.i32
+// SPIRV:      [[LOAD:%.*]] = load double, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[LOAD]] to <2 x i32>
 // SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
 // SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
 uint test_scalar(double D) {
@@ -20,14 +21,15 @@ uint test_scalar(double D) {
 }
 
 // CHECK: define {{.*}} <1 x i32> {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
-// CHECK: [[TRUNC:%.*]] = extractelement <1 x double> %D, i64 0
+// CHECK:      [[TRUNC:%.*]] = extractelement <1 x double> %D, i64 0
 // CHECK-NEXT: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[TRUNC]])
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+//
 // SPIRV: define spir_func {{.*}} <1 x i32> {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[REG:%.*]] = load <1 x double>, ptr [[VALD]].addr, align 8
-// SPIRV-NEXT: [[TRUNC:%.*]] = extractelement <1 x double> %1, i64 0
+// SPIRV-NOT:  @llvm.dx.splitdouble.i32
+// SPIRV:      [[LOAD:%.*]] = load <1 x double>, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[TRUNC:%.*]] = extractelement <1 x double> [[LOAD]], i64 0
 // SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[TRUNC]] to <2 x i32>
 // SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
 // SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
@@ -38,13 +40,14 @@ uint1 test_double1(double1 D) {
 }
 
 // CHECK: define {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%.*]] = {{.*}} call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> [[VALD]])
+// CHECK:      [[VALRET:%.*]] = {{.*}} call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> [[VALD]])
 // CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 1
+//
 // SPIRV: define spir_func {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[REG:%.*]] = load <2 x double>, ptr [[VALD]].addr, align 16
-// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[REG]] to <4 x i32>
+// SPIRV-NOT:  @llvm.dx.splitdouble.i32
+// SPIRV:      [[LOAD:%.*]] = load <2 x double>, ptr [[VALD]].addr, align 16
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[LOAD]] to <4 x i32>
 // SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
 // SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
 uint2 test_vector2(double2 D) {
@@ -54,20 +57,16 @@ uint2 test_vector2(double2 D) {
 }
 
 // CHECK: define {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
+// CHECK:      [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
 // CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
+//
 // SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align 32
-// SPIRV-NEXT: [[VALRET1:%.*]] = shufflevector <3 x double> [[REG]], <3 x double> poison, <2 x i32> <i32 0, i32 1>
-// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[VALRET1]] to <4 x i32>
-// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
-// SPIRV-NEXT: [[EXTRACT:%.*]] = extractelement <3 x double> [[REG]], i64 2
-// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[EXTRACT]] to <2 x i32>
-// SPIRV-NEXT: %[[#]] = shufflevector <2 x i32> [[SHUF1]], <2 x i32> [[CAST]], <3 x i32> <i32 0, i32 1, i32 2>
-// SPIRV-NEXT: %[[#]] = shufflevector <2 x i32> [[SHUF2]], <2 x i32> [[CAST]], <3 x i32> <i32 0, i32 1, i32 3>
+// SPIRV-NOT:  @llvm.dx.splitdouble.i32
+// SPIRV:      [[LOAD:%.*]] = load <3 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <3 x double> [[LOAD]] to <6 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <6 x i32> [[CAST1]], <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <6 x i32> [[CAST1]], <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
 uint3 test_vector3(double3 D) {
   uint3 A, B;
   asuint(D, A, B);
@@ -75,23 +74,16 @@ uint3 test_vector3(double3 D) {
 }
 
 // CHECK: define {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%.*]] = {{.*}} call { <4 x i32>, <4 x i32> } @llvm.dx.splitdouble.v4i32(<4 x double> [[VALD]])
+// CHECK:      [[VALRET:%.*]] = {{.*}} call { <4 x i32>, <4 x i32> } @llvm.dx.splitdouble.v4i32(<4 x double> [[VALD]])
 // CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 0
 // CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 1
+//
 // SPIRV: define spir_func {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
 // SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[REG:%.*]] = load <4 x double>, ptr [[VALD]].addr, align 32
-// SPIRV-NEXT: [[VALRET1:%.*]] = shufflevector <4 x double> [[REG]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
-// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[VALRET1]] to <4 x i32>
-// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
-// SPIRV-NEXT: [[VALRET2:%.*]] = shufflevector <4 x double> [[REG]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
-// SPIRV-NEXT: [[CAST2:%.*]] = bitcast <2 x double> [[VALRET2]] to <4 x i32>
-// SPIRV-NEXT: [[SHUF3:%.*]] = shufflevector <4 x i32> [[CAST2]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-// SPIRV-NEXT: [[SHUF4:%.*]] = shufflevector <4 x i32> [[CAST2]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
-// SPIRV-NEXT: shufflevector <2 x i32> [[SHUF1]], <2 x i32> [[SHUF3]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-// SPIRV-NEXT: shufflevector <2 x i32> [[SHUF2]], <2 x i32> [[SHUF4]], <4 x i32> <i32 0, i32 1, i32 3, i32 3>
-
+// SPIRV:      [[LOAD:%.*]] = load <4 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <4 x double> [[LOAD]] to <8 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <8 x i32> [[CAST1]], <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <8 x i32> [[CAST1]], <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
 uint4 test_vector4(double4 D) {
   uint4 A, B;
   asuint(D, A, B);
diff --git a/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
index 5bac11c4216a98..18d2b692b335b9 100644
--- a/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
@@ -41,12 +41,36 @@ void test_const_second_arg(double D) {
   const uint A = 1;
   uint B;
   __builtin_hlsl_elementwise_splitdouble(D, A, B);
- // expected-error at -1 {{cannot bind non-lvalue argument}} 
+ // expected-error at -1 {{cannot bind non-lvalue argument A to out paramemter}} 
 }
 
 void test_const_third_arg(double D) {
   uint A;
-  const uint B = 2;
+  const uint B = 1;
   __builtin_hlsl_elementwise_splitdouble(D, A, B);
- // expected-error at -1 {{cannot bind non-lvalue argument}} 
+ // expected-error at -1 {{cannot bind non-lvalue argument B to out paramemter}} 
+}
+
+void test_number_second_arg(double D) {
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, (uint)1, B);
+ // expected-error at -1 {{cannot bind non-lvalue argument (uint)1 to out paramemter}} 
+}
+
+void test_number_third_arg(double D) {
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, B, (uint)1);
+ // expected-error at -1 {{cannot bind non-lvalue argument (uint)1 to out paramemter}} 
+}
+
+void test_expr_second_arg(double D) {
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, B+1, B);
+ // expected-error at -1 {{cannot bind non-lvalue argument B + 1 to out paramemter}} 
+}
+
+void test_expr_third_arg(double D) {
+  uint B;
+  __builtin_hlsl_elementwise_splitdouble(D, B, B+1);
+ // expected-error at -1 {{cannot bind non-lvalue argument B + 1 to out paramemter}} 
 }
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index 52caad71e210b7..1443ba6269255a 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,4 +1,3 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 ; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,NOLOWER
 ; RUN: opt -passes='function(scalarizer),module(dxil-op-lower)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,WITHLOWER
 
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
index 2243e6b9944912..d18b16b843c37b 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
@@ -4,14 +4,10 @@
 ; Make sure lowering is correctly generating spirv code.
 
 ; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#vec_2_double:]] = OpTypeVector %[[#double]] 2
 ; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
-; CHECK-DAG: %[[#scalar_function:]] = OpTypeFunction %[[#int_32]] %[[#double]]
 ; CHECK-DAG: %[[#vec_2_int_32:]] = OpTypeVector %[[#int_32]] 2
 ; CHECK-DAG: %[[#vec_4_int_32:]] = OpTypeVector %[[#int_32]] 4
-; CHECK-DAG: %[[#vec_3_int_32:]] = OpTypeVector %[[#int_32]] 3
-; CHECK-DAG: %[[#vec_3_double:]] = OpTypeVector %[[#double]] 3
-; CHECK-DAG: %[[#vector_function:]] = OpTypeFunction %[[#vec_3_int_32]] %[[#vec_3_double]]
-; CHECK-DAG: %[[#vec_2_double:]] = OpTypeVector %[[#double]] 2
 
 
 define spir_func noundef i32 @test_scalar(double noundef %D) local_unnamed_addr {
@@ -29,26 +25,16 @@ entry:
 }
 
 
-define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
+define spir_func noundef <2 x i32> @test_vector(<2 x double> noundef %D) local_unnamed_addr {
 entry:
   ; CHECK-LABEL: ; -- Begin function test_vector
-  ; CHECK: %[[#param:]] = OpFunctionParameter %[[#vec_3_double]]
-  ; %[[#SHUFF1:]] = OpVectorShuffle %[[#vec_2_double]] %[[#param]] %[[#]] 0 1
-  ; %[[#CAST1:]] = OpBitcast %[[#vec_4_int_32]] %[[#SHUFF1]]
-  ; %[[#SHUFF2:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 0 2
-  ; %[[#SHUFF3:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 1 3
-  ; %[[#EXTRACT:]] = OpCompositeExtract %[[#double]] %[[#param]] 2
-  ; %[[#CAST2:]] = OpBitcast %[[#vec_2_int_32]] %[[#EXTRACT]]
-  ; %[[#]] = OpVectorShuffle %7 %[[#SHUFF2]] %[[#CAST2]] 0 1 2
-  ; %[[#]] = OpVectorShuffle %7 %[[#SHUFF3]] %[[#CAST2]] 0 1 3
-  %0 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 0, i32 1>
-  %1 = bitcast <2 x double> %0 to <4 x i32>
-  %2 = shufflevector <4 x i32> %1, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
-  %3 = shufflevector <4 x i32> %1, <4 x i32> poison, <2 x i32> <i32 1, i32 3>
-  %4 = extractelement <3 x double> %D, i64 2
-  %5 = bitcast double %4 to <2 x i32>
-  %6 = shufflevector <2 x i32> %2, <2 x i32> %5, <3 x i32> <i32 0, i32 1, i32 2>
-  %7 = shufflevector <2 x i32> %3, <2 x i32> %5, <3 x i32> <i32 0, i32 1, i32 3>
-  %add = add <3 x i32> %6, %7
-  ret <3 x i32> %add
+  ; CHECK: %[[#param:]] = OpFunctionParameter %[[#vec_2_double]]
+  ; CHECK: %[[#CAST1:]] = OpBitcast %[[#vec_4_int_32]] %[[#param]]
+  ; CHECK: %[[#SHUFF2:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 0 2
+  ; CHECK: %[[#SHUFF3:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 1 3
+  %0 = bitcast <2 x double> %D to <4 x i32>
+  %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+  %2 = shufflevector <4 x i32> %0, <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+  %add = add <2 x i32> %1, %2
+  ret <2 x i32> %add
 }



More information about the cfe-commits mailing list