[clang] 481bce0 - Adding splitdouble HLSL function (#109331)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Oct 28 13:27:04 PDT 2024
Author: joaosaffran
Date: 2024-10-28T13:26:59-07:00
New Revision: 481bce018ea8872277f79102842eaf8a55f634a2
URL: https://github.com/llvm/llvm-project/commit/481bce018ea8872277f79102842eaf8a55f634a2
DIFF: https://github.com/llvm/llvm-project/commit/481bce018ea8872277f79102842eaf8a55f634a2.diff
LOG: Adding splitdouble HLSL function (#109331)
- Adding hlsl `splitdouble` intrinsics
- Adding DXIL lowering
- Adding SPIRV lowering
- Adding test
Fixes: #108901
---------
Co-authored-by: Joao Saffran <jderezende at microsoft.com>
Added:
clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
llvm/test/CodeGen/DirectX/splitdouble.ll
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
Modified:
clang/include/clang/Basic/Builtins.td
clang/lib/CodeGen/CGBuiltin.cpp
clang/lib/CodeGen/CGCall.cpp
clang/lib/CodeGen/CGExpr.cpp
clang/lib/CodeGen/CodeGenFunction.h
clang/lib/Headers/hlsl/hlsl_intrinsics.h
clang/lib/Sema/SemaHLSL.cpp
clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
llvm/lib/Target/DirectX/DXIL.td
llvm/lib/Target/DirectX/DXILOpBuilder.cpp
llvm/lib/Target/DirectX/DXILOpBuilder.h
llvm/lib/Target/DirectX/DXILOpLowering.cpp
Removed:
llvm/test/CodeGen/DirectX/split-double.ll
################################################################################
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..9bd67e0cefebc3 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,6 +4871,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_splitdouble"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a57c95d5b96672..65d7f5c54a1913 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17,6 +17,7 @@
#include "CGObjCRuntime.h"
#include "CGOpenCLRuntime.h"
#include "CGRecordLayout.h"
+#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
@@ -25,8 +26,10 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
+#include "clang/AST/Expr.h"
#include "clang/AST/OSLog.h"
#include "clang/AST/OperationKinds.h"
+#include "clang/AST/Type.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/TargetOptions.h"
@@ -67,6 +70,7 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
+#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -95,6 +99,76 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
I->addAnnotationMetadata("auto-init");
}
+static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
+ Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
+ const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+ const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+ CallArgList Args;
+ LValue Op1TmpLValue =
+ CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+ LValue Op2TmpLValue =
+ CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+ if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+ Args.reverseWritebacks();
+
+ Value *LowBits = nullptr;
+ Value *HighBits = nullptr;
+
+ if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+
+ llvm::Type *RetElementTy = CGF->Int32Ty;
+ if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
+ RetElementTy = llvm::VectorType::get(
+ CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+ auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+
+ CallInst *CI = CGF->Builder.CreateIntrinsic(
+ RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+
+ LowBits = CGF->Builder.CreateExtractValue(CI, 0);
+ HighBits = CGF->Builder.CreateExtractValue(CI, 1);
+
+ } else {
+ // For non-DXIL targets, generate the instructions directly.
+
+ if (!Op0->getType()->isVectorTy()) {
+ FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+ Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+
+ LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
+ HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
+ } else {
+ int NumElements = 1;
+ if (const auto *VecTy =
+ E->getArg(0)->getType()->getAs<clang::VectorType>())
+ NumElements = VecTy->getNumElements();
+
+ FixedVectorType *Uint32VecTy =
+ FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
+ Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
+ if (NumElements == 1) {
+ LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
+ HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
+ } else {
+ SmallVector<int> EvenMask, OddMask;
+ for (int I = 0, E = NumElements; I != E; ++I) {
+ EvenMask.push_back(I * 2);
+ OddMask.push_back(I * 2 + 1);
+ }
+ LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
+ HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
+ }
+ }
+ }
+ CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
+ auto *LastInst =
+ CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
+ CGF->EmitWritebacks(Args);
+ return LastInst;
+}
+
/// getBuiltinLibFunction - Given a builtin id for a function like
/// "__builtin_fabsf", return a Function* for "fabsf".
llvm::Constant *CodeGenModule::getBuiltinLibFunction(const FunctionDecl *FD,
@@ -18959,6 +19033,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
nullptr, "hlsl.radians");
}
+ case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
+
+ assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+ E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+ "asuint operands types mismatch");
+ return handleHlslSplitdouble(E, this);
+ }
}
return nullptr;
}
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 64e60f0616d77b..8f4f5d3ed81601 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -40,6 +40,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/Path.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>
using namespace clang;
@@ -4243,12 +4244,6 @@ static void emitWriteback(CodeGenFunction &CGF,
CGF.EmitBlock(contBB);
}
-static void emitWritebacks(CodeGenFunction &CGF,
- const CallArgList &args) {
- for (const auto &I : args.writebacks())
- emitWriteback(CGF, I);
-}
-
static void deactivateArgCleanupsBeforeCall(CodeGenFunction &CGF,
const CallArgList &CallArgs) {
ArrayRef<CallArgList::CallArgCleanup> Cleanups =
@@ -4717,6 +4712,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
IsUsed = true;
}
+void CodeGenFunction::EmitWritebacks(const CallArgList &args) {
+ for (const auto &I : args.writebacks())
+ emitWriteback(*this, I);
+}
+
void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
QualType type) {
DisableDebugLocationUpdates Dis(*this, E);
@@ -5940,7 +5940,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
// Emit any call-associated writebacks immediately. Arguably this
// should happen after any return-value munging.
if (CallArgs.hasWritebacks())
- emitWritebacks(*this, CallArgs);
+ EmitWritebacks(CallArgs);
// The stack cleanup for inalloca arguments has to run out of the normal
// lexical order, so deactivate it and run it manually here.
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e0ea65bcaf3637..e90e8da3e9f1ea 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5460,9 +5460,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
return getOrCreateOpaqueLValueMapping(e);
}
-void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
- CallArgList &Args, QualType Ty) {
-
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
// Emitting the casted temporary through an opaque value.
LValue BaseLV = EmitLValue(E->getArgLValue());
OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
@@ -5476,6 +5475,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
TempLV);
OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
+ return std::make_pair(BaseLV, TempLV);
+}
+
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+ CallArgList &Args, QualType Ty) {
+
+ auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
llvm::Value *Addr = TempLV.getAddress().getBasePointer();
llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
@@ -5488,6 +5494,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
+ return TempLV;
}
LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 750a6cc24badca..3ff4458fb32024 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache {
LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
LValue EmitHLSLArrayAssignLValue(const BinaryOperator *E);
- void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
- QualType Ty);
+
+ std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
+ QualType Ty);
+ LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+ QualType Ty);
Address EmitExtVectorElementLValue(LValue V);
@@ -5147,6 +5150,9 @@ class CodeGenFunction : public CodeGenTypeCache {
SourceLocation ArgLoc, AbstractCallee AC,
unsigned ParmNum);
+ /// EmitWritebacks - Emit the writebacks recorded in the call argument list.
+ void EmitWritebacks(const CallArgList &Args);
+
/// EmitCallArg - Emit a single call argument.
void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..8ade4b27f360fb 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -438,6 +438,24 @@ template <typename T> constexpr uint asuint(T F) {
return __detail::bit_cast<uint, T>(F);
}
+//===----------------------------------------------------------------------===//
+// asuint splitdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn void asuint(double D, out uint lowbits, out uint highbits)
+/// \brief Splits and interprets the lowbits and highbits of double D into uints.
+/// \param D The input double.
+/// \param lowbits The output lowbits of D.
+/// \param highbits The output highbits of D.
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
+void asuint(double, out uint, out uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
+void asuint(double2, out uint2, out uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
+void asuint(double3, out uint3, out uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
+void asuint(double4, out uint4, out uint4);
+
//===----------------------------------------------------------------------===//
// atan builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..a472538236e2d9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,18 +1698,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}
-static bool CheckArgsTypesAreCorrect(
+bool CheckArgTypeIsCorrect(
+ Sema *S, Expr *Arg, QualType ExpectedType,
+ llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+ QualType PassedType = Arg->getType();
+ if (Check(PassedType)) {
+ if (auto *VecTyA = PassedType->getAs<VectorType>())
+ ExpectedType = S->Context.getVectorType(
+ ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
+ S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+ << PassedType << ExpectedType << 1 << 0 << 0;
+ return true;
+ }
+ return false;
+}
+
+bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
- QualType PassedType = TheCall->getArg(i)->getType();
- if (Check(PassedType)) {
- if (auto *VecTyA = PassedType->getAs<VectorType>())
- ExpectedType = S->Context.getVectorType(
- ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
- S->Diag(TheCall->getArg(0)->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << PassedType << ExpectedType << 1 << 0 << 0;
+ Expr *Arg = TheCall->getArg(i);
+ if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
@@ -1720,8 +1729,8 @@ static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
};
- return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
- checkAllFloatTypes);
+ return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+ checkAllFloatTypes);
}
static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
@@ -1732,8 +1741,19 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
: PassedType;
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
};
- return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
- checkFloatorHalf);
+ return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+ checkFloatorHalf);
+}
+
+static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
+ unsigned ArgIndex) {
+ auto *Arg = TheCall->getArg(ArgIndex);
+ SourceLocation OrigLoc = Arg->getExprLoc();
+ if (Arg->IgnoreCasts()->isModifiableLvalue(S->Context, &OrigLoc) ==
+ Expr::MLV_Valid)
+ return false;
+ S->Diag(OrigLoc, diag::error_hlsl_inout_lvalue) << Arg << 0;
+ return true;
}
static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
@@ -1742,24 +1762,24 @@ static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
return VecTy->getElementType()->isDoubleType();
return false;
};
- return CheckArgsTypesAreCorrect(S, TheCall, S->Context.FloatTy,
- checkDoubleVector);
+ return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
+ checkDoubleVector);
}
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasIntegerRepresentation() &&
!PassedType->hasFloatingRepresentation();
};
- return CheckArgsTypesAreCorrect(S, TheCall, S->Context.IntTy,
- checkAllSignedTypes);
+ return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
+ checkAllSignedTypes);
}
static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasUnsignedIntegerRepresentation();
};
- return CheckArgsTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
- checkAllUnsignedTypes);
+ return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
+ checkAllUnsignedTypes);
}
static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
@@ -2074,6 +2094,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ if (CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.DoubleTy, 0) ||
+ CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+ 1) ||
+ CheckScalarOrVector(&SemaRef, TheCall, SemaRef.Context.UnsignedIntTy,
+ 2))
+ return true;
+
+ if (CheckModifiableLValue(&SemaRef, TheCall, 1) ||
+ CheckModifiableLValue(&SemaRef, TheCall, 2))
+ return true;
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
new file mode 100644
index 00000000000000..a883c9d5cc3555
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -0,0 +1,91 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s --check-prefix=SPIRV
+
+
+
+// CHECK: define {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+//
+// SPIRV: define spir_func {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[LOAD:%.*]] = load double, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[LOAD]] to <2 x i32>
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
+uint test_scalar(double D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <1 x i32> {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[TRUNC:%.*]] = extractelement <1 x double> %D, i64 0
+// CHECK-NEXT: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[TRUNC]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+//
+// SPIRV: define spir_func {{.*}} <1 x i32> {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[LOAD:%.*]] = load <1 x double>, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[TRUNC:%.*]] = extractelement <1 x double> [[LOAD]], i64 0
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[TRUNC]] to <2 x i32>
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
+uint1 test_double1(double1 D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 1
+//
+// SPIRV: define spir_func {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[LOAD:%.*]] = load <2 x double>, ptr [[VALD]].addr, align 16
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[LOAD]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+uint2 test_vector2(double2 D) {
+ uint2 A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
+//
+// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[LOAD:%.*]] = load <3 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <3 x double> [[LOAD]] to <6 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <6 x i32> [[CAST1]], <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <6 x i32> [[CAST1]], <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
+uint3 test_vector3(double3 D) {
+ uint3 A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <4 x i32>, <4 x i32> } @llvm.dx.splitdouble.v4i32(<4 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 1
+//
+// SPIRV: define spir_func {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[LOAD:%.*]] = load <4 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <4 x double> [[LOAD]] to <8 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <8 x i32> [[CAST1]], <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <8 x i32> [[CAST1]], <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+uint4 test_vector4(double4 D) {
+ uint4 A, B;
+ asuint(D, A, B);
+ return A + B;
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
index 8c56fdddb1c24c..4adb0555c35be6 100644
--- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -6,6 +6,10 @@ uint4 test_asuint_too_many_arg(float p0, float p1) {
// expected-error@-1 {{no matching function for call to 'asuint'}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'V', but 2 arguments were provided}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'F', but 2 arguments were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
}
uint test_asuint_double(double p1) {
@@ -23,3 +27,29 @@ uint test_asuint_half(half p1) {
// expected-note@hlsl/hlsl_detail.h:* {{candidate template ignored: could not match 'vector<half, N>' against 'half'}}
// expected-note@hlsl/hlsl_detail.h:* {{candidate template ignored: substitution failure [with U = uint, T = half]: no type named 'Type'}}
}
+
+void test_asuint_first_arg_const(double D) {
+ const uint A = 0;
+ uint B;
+ asuint(D, A, B);
+ // expected-error@hlsl/hlsl_intrinsics.h:* {{read-only variable is not assignable}}
+}
+
+void test_asuint_second_arg_const(double D) {
+ const uint A = 0;
+ uint B;
+ asuint(D, B, A);
+ // expected-error@hlsl/hlsl_intrinsics.h:* {{read-only variable is not assignable}}
+}
+
+void test_asuint_imidiate_value(double D) {
+ uint B;
+ asuint(D, B, 1);
+ // expected-error@-1 {{cannot bind non-lvalue argument 1 to out paramemter}}
+}
+
+void test_asuint_expr(double D) {
+ uint B;
+ asuint(D, B, B + 1);
+ // expected-error@-1 {{cannot bind non-lvalue argument B + 1 to out paramemter}}
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
new file mode 100644
index 00000000000000..18d2b692b335b9
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/splitdouble-errors.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -verify
+
+void test_no_second_arg(double D) {
+ __builtin_hlsl_elementwise_splitdouble(D);
+ // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+}
+
+void test_no_third_arg(double D) {
+ uint A;
+ __builtin_hlsl_elementwise_splitdouble(D, A);
+ // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+void test_too_many_arg(double D) {
+ uint A, B, C;
+ __builtin_hlsl_elementwise_splitdouble(D, A, B, C);
+ // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
+
+void test_first_arg_type_mismatch(bool3 D) {
+ uint3 A, B;
+ __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error@-1 {{invalid operand of type 'bool3' (aka 'vector<bool, 3>') where 'double' or a vector of such type is required}}
+}
+
+void test_second_arg_type_mismatch(double D) {
+ bool A;
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error@-1 {{invalid operand of type 'bool' where 'unsigned int' or a vector of such type is required}}
+}
+
+void test_third_arg_type_mismatch(double D) {
+ bool A;
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, B, A);
+ // expected-error@-1 {{invalid operand of type 'bool' where 'unsigned int' or a vector of such type is required}}
+}
+
+void test_const_second_arg(double D) {
+ const uint A = 1;
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error@-1 {{cannot bind non-lvalue argument A to out paramemter}}
+}
+
+void test_const_third_arg(double D) {
+ uint A;
+ const uint B = 1;
+ __builtin_hlsl_elementwise_splitdouble(D, A, B);
+ // expected-error@-1 {{cannot bind non-lvalue argument B to out paramemter}}
+}
+
+void test_number_second_arg(double D) {
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, (uint)1, B);
+ // expected-error@-1 {{cannot bind non-lvalue argument (uint)1 to out paramemter}}
+}
+
+void test_number_third_arg(double D) {
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, B, (uint)1);
+ // expected-error@-1 {{cannot bind non-lvalue argument (uint)1 to out paramemter}}
+}
+
+void test_expr_second_arg(double D) {
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, B+1, B);
+ // expected-error@-1 {{cannot bind non-lvalue argument B + 1 to out paramemter}}
+}
+
+void test_expr_third_arg(double D) {
+ uint B;
+ __builtin_hlsl_elementwise_splitdouble(D, B, B+1);
+ // expected-error@-1 {{cannot bind non-lvalue argument B + 1 to out paramemter}}
+}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..68ae5de06423c2 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,6 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
+def SplitDoubleTy : DXILOpParamType;
class DXILOpClass;
@@ -779,6 +780,15 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+def SplitDouble : DXILOp<102, splitDouble> {
+ let Doc = "Splits a double into 2 uints";
+ let arguments = [OverloadTy];
+ let result = SplitDoubleTy;
+ let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
def AnnotateHandle : DXILOp<217, annotateHandle> {
let Doc = "annotate handle with resource properties";
let arguments = [HandleTy, ResPropsTy];
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 7719d6b1079110..5d5bb3eacace25 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,6 +229,13 @@ static StructType *getResPropsType(LLVMContext &Context) {
return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
}
+static StructType *getSplitDoubleType(LLVMContext &Context) {
+ if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
+ return ST;
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
+}
+
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
Type *OverloadTy) {
switch (Kind) {
@@ -266,6 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getResBindType(Ctx);
case OpParamType::ResPropsTy:
return getResPropsType(Ctx);
+ case OpParamType::SplitDoubleTy:
+ return getSplitDoubleType(Ctx);
}
llvm_unreachable("Invalid parameter kind");
return nullptr;
@@ -467,6 +476,10 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
return ::getResRetType(ElementTy);
}
+StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) {
+ return ::getSplitDoubleType(Context);
+}
+
StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 037ae3822cfb90..df5a0240870f4a 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -49,6 +49,10 @@ class DXILOpBuilder {
/// Get a `%dx.types.ResRet` type with the given element type.
StructType *getResRetType(Type *ElementTy);
+
+ /// Get the `%dx.types.splitdouble` type.
+ StructType *getSplitDoubleType(LLVMContext &Context);
+
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index c62ba8c21d6791..f7722d77074764 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
@@ -128,6 +129,30 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool replaceFunctionWithNamedStructOp(
+ Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
+ llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
+ bool IsVectorArgExpansion = isVectorArgExpansion(F);
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ SmallVector<Value *> Args;
+ OpBuilder.getIRB().SetInsertPoint(CI);
+ if (IsVectorArgExpansion) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+ if (Error E = ReplaceUses(CI, *OpCall))
+ return E;
+
+ return Error::success();
+ });
+ }
+
/// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
/// is intended to be removed by the end of lowering. This is used to allow
/// lowering of ops which need to change their return or argument types in a
@@ -263,6 +288,26 @@ class OpLowerer {
return lowerToBindAndAnnotateHandle(F);
}
+ Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) { // Redirect the extractvalue users of Intrin to Op, then erase Intrin.
+ for (Use &U : make_early_inc_range(Intrin->uses())) { // Early-increment range: setOperand below takes U off Intrin's use list during iteration.
+ if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
+
+ if (EVI->getNumIndices() != 1) // Only single-index extractvalue users are accepted.
+ return createStringError(std::errc::invalid_argument,
+ "Splitdouble has only 2 elements");
+ EVI->setOperand(0, Op); // Only the aggregate operand is swapped; the extraction index is left as it was.
+ } else {
+ return make_error<StringError>( // Any user that is not an extractvalue aborts with an error; Intrin is then not erased.
+ "Splitdouble use is not ExtractValueInst",
+ inconvertibleErrorCode());
+ }
+ }
+
+ Intrin->eraseFromParent(); // Reached only after every use has been redirected to Op.
+
+ return Error::success();
+ }
+
/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
@@ -488,6 +533,16 @@ class OpLowerer {
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
break;
+ // TODO: this can be removed when
+ // https://github.com/llvm/llvm-project/issues/113192 is fixed
+ case Intrinsic::dx_splitdouble:
+ HasErrors |= replaceFunctionWithNamedStructOp(
+ F, OpCode::SplitDouble,
+ OpBuilder.getSplitDoubleType(M.getContext()),
+ [&](CallInst *CI, CallInst *Op) {
+ return replaceSplitDoubleCallUsages(CI, Op);
+ });
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/split-double.ll b/llvm/test/CodeGen/DirectX/split-double.ll
deleted file mode 100644
index 759590fa56279b..00000000000000
--- a/llvm/test/CodeGen/DirectX/split-double.ll
+++ /dev/null
@@ -1,45 +0,0 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-
-define void @test_vector_double_split_void(<2 x double> noundef %d) {
-; CHECK-LABEL: define void @test_vector_double_split_void(
-; CHECK-SAME: <2 x double> noundef [[D:%.*]]) {
-; CHECK-NEXT: [[D_I0:%.*]] = extractelement <2 x double> [[D]], i64 0
-; CHECK-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
-; CHECK-NEXT: [[D_I1:%.*]] = extractelement <2 x double> [[D]], i64 1
-; CHECK-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
-; CHECK-NEXT: ret void
-;
- %hlsl.asuint = call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> %d)
- ret void
-}
-
-define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) {
-; CHECK-LABEL: define noundef <3 x i32> @test_vector_double_split(
-; CHECK-SAME: <3 x double> noundef [[D:%.*]]) {
-; CHECK-NEXT: [[D_I0:%.*]] = extractelement <3 x double> [[D]], i64 0
-; CHECK-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
-; CHECK-NEXT: [[D_I1:%.*]] = extractelement <3 x double> [[D]], i64 1
-; CHECK-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
-; CHECK-NEXT: [[D_I2:%.*]] = extractelement <3 x double> [[D]], i64 2
-; CHECK-NEXT: [[HLSL_ASUINT_I2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I2]])
-; CHECK-NEXT: [[DOTELEM0:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
-; CHECK-NEXT: [[DOTELEM01:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 0
-; CHECK-NEXT: [[DOTELEM02:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 0
-; CHECK-NEXT: [[DOTELEM1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
-; CHECK-NEXT: [[DOTELEM13:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 1
-; CHECK-NEXT: [[DOTELEM14:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 1
-; CHECK-NEXT: [[DOTI0:%.*]] = add i32 [[DOTELEM0]], [[DOTELEM1]]
-; CHECK-NEXT: [[DOTI1:%.*]] = add i32 [[DOTELEM01]], [[DOTELEM13]]
-; CHECK-NEXT: [[DOTI2:%.*]] = add i32 [[DOTELEM02]], [[DOTELEM14]]
-; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <3 x i32> poison, i32 [[DOTI0]], i64 0
-; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <3 x i32> [[DOTUPTO015]], i32 [[DOTI1]], i64 1
-; CHECK-NEXT: [[TMP1:%.*]] = insertelement <3 x i32> [[DOTUPTO116]], i32 [[DOTI2]], i64 2
-; CHECK-NEXT: ret <3 x i32> [[TMP1]]
-;
- %hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
- %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
- %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
- %3 = add <3 x i32> %1, %2
- ret <3 x i32> %3
-}
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
new file mode 100644
index 00000000000000..1443ba6269255a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -0,0 +1,76 @@
+; RUN: opt -passes='function(scalarizer)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,NOLOWER
+; RUN: opt -passes='function(scalarizer),module(dxil-op-lower)' -S -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,WITHLOWER
+
+define i32 @test_scalar(double noundef %D) { ; Scalar input: the expected output keeps one intrinsic call, and only the run that adds dxil-op-lower expects @dx.op.splitDouble.f64 instead.
+; CHECK-LABEL: define i32 @test_scalar(
+; CHECK-SAME: double noundef [[D:%.*]]) {
+; NOLOWER-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D]])
+; WITHLOWER-NEXT: [[HLSL_ASUINT_I0:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D]])
+; NOLOWER-NEXT: [[EV1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
+; NOLOWER-NEXT: [[EV2:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
+; WITHLOWER-NEXT: [[EV1:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 0
+; WITHLOWER-NEXT: [[EV2:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 1
+; CHECK-NEXT: [[ADD:%.*]] = add i32 [[EV1]], [[EV2]]
+; CHECK-NEXT: ret i32 [[ADD]]
+;
+ %hlsl.splitdouble = call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+ %1 = extractvalue { i32, i32 } %hlsl.splitdouble, 0
+ %2 = extractvalue { i32, i32 } %hlsl.splitdouble, 1
+ %add = add i32 %1, %2 ; consumes both extracted fields
+ ret i32 %add
+}
+
+
+define void @test_vector_double_split_void(<2 x double> noundef %d) { ; The split result is never used; the expected output still has one scalar call per element of %d.
+; CHECK-LABEL: define void @test_vector_double_split_void(
+; CHECK-SAME: <2 x double> noundef [[D:%.*]]) {
+; CHECK-NEXT: [[D_I0:%.*]] = extractelement <2 x double> [[D]], i64 0
+; NOLOWER-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
+; WITHLOWER-NEXT: [[HLSL_ASUINT_I0:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I0]])
+; CHECK-NEXT: [[D_I1:%.*]] = extractelement <2 x double> [[D]], i64 1
+; NOLOWER-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
+; WITHLOWER-NEXT: [[HLSL_ASUINT_I1:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I1]])
+; CHECK-NEXT: ret void
+;
+ %hlsl.asuint = call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> %d) ; vector overload: returns a pair of <2 x i32> values
+ ret void
+}
+
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %d) { ; Both result vectors are consumed, so the expected output also has per-element extractvalue, add and insertelement instructions.
+; CHECK-LABEL: define noundef <3 x i32> @test_vector_double_split(
+; CHECK-SAME: <3 x double> noundef [[D:%.*]]) {
+; CHECK-NEXT: [[D_I0:%.*]] = extractelement <3 x double> [[D]], i64 0
+; NOLOWER-NEXT: [[HLSL_ASUINT_I0:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I0]])
+; WITHLOWER-NEXT: [[HLSL_ASUINT_I0:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I0]])
+; CHECK-NEXT: [[D_I1:%.*]] = extractelement <3 x double> [[D]], i64 1
+; NOLOWER-NEXT: [[HLSL_ASUINT_I1:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I1]])
+; WITHLOWER-NEXT: [[HLSL_ASUINT_I1:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I1]])
+; CHECK-NEXT: [[D_I2:%.*]] = extractelement <3 x double> [[D]], i64 2
+; NOLOWER-NEXT: [[HLSL_ASUINT_I2:%.*]] = call { i32, i32 } @llvm.dx.splitdouble.i32(double [[D_I2]])
+; WITHLOWER-NEXT: [[HLSL_ASUINT_I2:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double [[D_I2]])
+; NOLOWER-NEXT: [[DOTELEM0:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 0
+; WITHLOWER-NEXT: [[DOTELEM0:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 0
+; NOLOWER-NEXT: [[DOTELEM01:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 0
+; WITHLOWER-NEXT: [[DOTELEM01:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I1]], 0
+; NOLOWER-NEXT: [[DOTELEM02:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 0
+; WITHLOWER-NEXT: [[DOTELEM02:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I2]], 0
+; NOLOWER-NEXT: [[DOTELEM1:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I0]], 1
+; WITHLOWER-NEXT: [[DOTELEM1:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I0]], 1
+; NOLOWER-NEXT: [[DOTELEM13:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I1]], 1
+; WITHLOWER-NEXT: [[DOTELEM13:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I1]], 1
+; NOLOWER-NEXT: [[DOTELEM14:%.*]] = extractvalue { i32, i32 } [[HLSL_ASUINT_I2]], 1
+; WITHLOWER-NEXT: [[DOTELEM14:%.*]] = extractvalue %dx.types.splitdouble [[HLSL_ASUINT_I2]], 1
+; CHECK-NEXT: [[DOTI0:%.*]] = add i32 [[DOTELEM0]], [[DOTELEM1]]
+; CHECK-NEXT: [[DOTI1:%.*]] = add i32 [[DOTELEM01]], [[DOTELEM13]]
+; CHECK-NEXT: [[DOTI2:%.*]] = add i32 [[DOTELEM02]], [[DOTELEM14]]
+; CHECK-NEXT: [[DOTUPTO015:%.*]] = insertelement <3 x i32> poison, i32 [[DOTI0]], i64 0
+; CHECK-NEXT: [[DOTUPTO116:%.*]] = insertelement <3 x i32> [[DOTUPTO015]], i32 [[DOTI1]], i64 1
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <3 x i32> [[DOTUPTO116]], i32 [[DOTI2]], i64 2
+; CHECK-NEXT: ret <3 x i32> [[TMP1]]
+;
+ %hlsl.asuint = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %d)
+ %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 0
+ %2 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.asuint, 1
+ %3 = add <3 x i32> %1, %2 ; elementwise sum of the two result vectors
+ ret <3 x i32> %3
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
new file mode 100644
index 00000000000000..d18b16b843c37b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
@@ -0,0 +1,40 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure lowering is correctly generating spirv code.
+
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#vec_2_double:]] = OpTypeVector %[[#double]] 2
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#vec_2_int_32:]] = OpTypeVector %[[#int_32]] 2
+; CHECK-DAG: %[[#vec_4_int_32:]] = OpTypeVector %[[#int_32]] 4
+
+
+define spir_func noundef i32 @test_scalar(double noundef %D) local_unnamed_addr { ; Scalar case: the double is bitcast to <2 x i32> and each lane is extracted.
+entry:
+ ; CHECK-LABEL: ; -- Begin function test_scalar
+ ; CHECK: %[[#param:]] = OpFunctionParameter %[[#double]]
+ ; CHECK: %[[#bitcast:]] = OpBitcast %[[#vec_2_int_32]] %[[#param]]
+ %0 = bitcast double %D to <2 x i32>
+ ; CHECK: %[[#]] = OpCompositeExtract %[[#int_32]] %[[#bitcast]] 0
+ %1 = extractelement <2 x i32> %0, i64 0
+ ; CHECK: %[[#]] = OpCompositeExtract %[[#int_32]] %[[#bitcast]] 1
+ %2 = extractelement <2 x i32> %0, i64 1
+ %add = add i32 %1, %2 ; consumes both extracted lanes
+ ret i32 %add
+}
+
+
+define spir_func noundef <2 x i32> @test_vector(<2 x double> noundef %D) local_unnamed_addr { ; Vector case: <2 x double> is bitcast to <4 x i32>, then shuffled into lanes {0,2} and {1,3}.
+entry:
+ ; CHECK-LABEL: ; -- Begin function test_vector
+ ; CHECK: %[[#param:]] = OpFunctionParameter %[[#vec_2_double]]
+ ; CHECK: %[[#CAST1:]] = OpBitcast %[[#vec_4_int_32]] %[[#param]]
+ ; CHECK: %[[#SHUFF2:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 0 2
+ ; CHECK: %[[#SHUFF3:]] = OpVectorShuffle %[[#vec_2_int_32]] %[[#CAST1]] %[[#]] 1 3
+ %0 = bitcast <2 x double> %D to <4 x i32>
+ %1 = shufflevector <4 x i32> %0, <4 x i32> poison, <2 x i32> <i32 0, i32 2> ; even lanes of the bitcast
+ %2 = shufflevector <4 x i32> %0, <4 x i32> poison, <2 x i32> <i32 1, i32 3> ; odd lanes of the bitcast
+ %add = add <2 x i32> %1, %2
+ ret <2 x i32> %add
+}
More information about the cfe-commits
mailing list