[clang] [llvm] Adding splitdouble HLSL function (PR #109331)
via cfe-commits
cfe-commits at lists.llvm.org
Wed Sep 25 22:32:35 PDT 2024
https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/109331
>From 50d21754119ac10c2ee2376ed8f79d12f73cd137 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 19 Sep 2024 00:13:51 +0000
Subject: [PATCH 1/5] Codegen builtin
---
clang/include/clang/Basic/Builtins.td | 6 ++
clang/lib/CodeGen/CGBuiltin.cpp | 38 ++++++++++++
clang/lib/CodeGen/CGCall.cpp | 5 ++
clang/lib/CodeGen/CGExpr.cpp | 15 ++++-
clang/lib/CodeGen/CodeGenFunction.h | 10 +++-
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 20 +++++++
clang/lib/Sema/SemaHLSL.cpp | 58 ++++++++++++++++---
.../builtins/asuint-splitdouble.hlsl | 10 ++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 5 ++
llvm/lib/Target/DirectX/DXIL.td | 1 +
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 13 +++++
11 files changed, 167 insertions(+), 14 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8c5d7ad763bf97..b38957f6e3f15d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4788,6 +4788,12 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 249aead33ad73d..f7695b8693f3dc 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18843,6 +18843,44 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
retType, CGM.getHLSLRuntime().getSignIntrinsic(),
ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
}
+ // This should only be called when targeting DXIL
+ case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+
+ assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+ E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+ "asuint operands types mismatch");
+
+ Value *Op0 = EmitScalarExpr(E->getArg(0));
+ const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+ const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+ CallArgList Args;
+ LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+ LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+ llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+ if (Op0->getType()->isVectorTy()) {
+ auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+ llvm::VectorType *i32VecTy = llvm::VectorType::get(
+ Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+
+ retType = llvm::StructType::get(i32VecTy, i32VecTy);
+ }
+
+ CallInst *CI =
+ Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
+ {Op0}, nullptr, "hlsl.asuint");
+
+ Value *arg0 = Builder.CreateExtractValue(CI, 0);
+ Value *arg1 = Builder.CreateExtractValue(CI, 1);
+
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+ EmitWritebacks(*this, Args);
+ return s;
+ }
}
return nullptr;
}
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4ae981e4013e9c..096bbafa4cc694 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4681,6 +4681,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
IsUsed = true;
}
+void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF,
+ const CallArgList &args) {
+ emitWritebacks(CGF, args);
+}
+
void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
QualType type) {
DisableDebugLocationUpdates Dis(*this, E);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 9166db4c74128c..1c299c4a932ca0 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,6 +19,7 @@
#include "CGObjCRuntime.h"
#include "CGOpenMPRuntime.h"
#include "CGRecordLayout.h"
+#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
@@ -28,6 +29,7 @@
#include "clang/AST/DeclObjC.h"
#include "clang/AST/NSAPI.h"
#include "clang/AST/StmtVisitor.h"
+#include "clang/AST/Type.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/CodeGenOptions.h"
#include "clang/Basic/SourceManager.h"
@@ -5458,9 +5460,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
return getOrCreateOpaqueLValueMapping(e);
}
-void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
- CallArgList &Args, QualType Ty) {
-
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
// Emitting the casted temporary through an opaque value.
LValue BaseLV = EmitLValue(E->getArgLValue());
OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
@@ -5474,6 +5475,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
TempLV);
OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
+ return std::make_pair(BaseLV, TempLV);
+}
+
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+ CallArgList &Args, QualType Ty) {
+
+ auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
llvm::Value *Addr = TempLV.getAddress().getBasePointer();
llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
@@ -5486,6 +5494,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
+ return TempLV;
}
LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 3e2abbd9bc1094..ad7c2635500d93 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache {
LValue EmitCastLValue(const CastExpr *E);
LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
- void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
- QualType Ty);
+
+ std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
+ QualType Ty);
+ LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+ QualType Ty);
Address EmitExtVectorElementLValue(LValue V);
@@ -5147,6 +5150,9 @@ class CodeGenFunction : public CodeGenTypeCache {
SourceLocation ArgLoc, AbstractCallee AC,
unsigned ParmNum);
+ /// EmitWritebacks - Emit the pending writebacks recorded in \p args.
+ void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args);
+
/// EmitCallArg - Emit a single call argument.
void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index b139f9eb7d999b..e8a1e97f344559 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -422,6 +422,26 @@ template <typename T> constexpr uint asuint(T F) {
return __detail::bit_cast<uint, T>(F);
}
+//===----------------------------------------------------------------------===//
+// asuint splitdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn void asuint(double D, out uint lowbits, out uint highbits)
+/// \brief Splits double D and reinterprets its low and high bits as uints.
+/// \param D The input double.
+/// \param lowbits The output lowbits of D.
+/// \param highbits The highbits lowbits D.
+#if __is_target_arch(dxil)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double, out uint, out uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double2, out uint2, out uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double3, out uint3, out uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double4, out uint4, out uint4);
+#endif
+
//===----------------------------------------------------------------------===//
// atan builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index ebe76185cbb2d5..20878442c92338 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1467,18 +1467,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}
-static bool CheckArgsTypesAreCorrect(
+bool CheckArgTypeIsCorrect(
+ Sema *S, Expr *Arg, QualType ExpectedType,
+ llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+ QualType PassedType = Arg->getType();
+ if (Check(PassedType)) {
+ if (auto *VecTyA = PassedType->getAs<VectorType>())
+ ExpectedType = S->Context.getVectorType(
+ ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
+ S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+ << PassedType << ExpectedType << 1 << 0 << 0;
+ return true;
+ }
+ return false;
+}
+
+bool CheckArgsTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
- QualType PassedType = TheCall->getArg(i)->getType();
- if (Check(PassedType)) {
- if (auto *VecTyA = PassedType->getAs<VectorType>())
- ExpectedType = S->Context.getVectorType(
- ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
- S->Diag(TheCall->getArg(0)->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << PassedType << ExpectedType << 1 << 0 << 0;
+ Expr *Arg = TheCall->getArg(i);
+ if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
@@ -1762,6 +1771,37 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ // Expr *Op0 = TheCall->getArg(0);
+
+ // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+ // return !PassedType->isDoubleType();
+ // };
+
+ // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+ // CheckIsNotDouble)) {
+ // return true;
+ // }
+
+ // Expr *Op1 = TheCall->getArg(1);
+ // Expr *Op2 = TheCall->getArg(2);
+
+ // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+ // return !PassedType->isUnsignedIntegerType();
+ // };
+
+ // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+ // CheckIsNotUint) ||
+ // CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+ // CheckIsNotUint)) {
+ // return true;
+ // }
+
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
new file mode 100644
index 00000000000000..e359354dc3a6df
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+
+// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}}
+// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
+// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
+float fn(double D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 3ce7b8b987ef86..d8092397881550 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -88,4 +88,9 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+
+def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+ [llvm_anyint_ty, LLVMMatchType<0>],
+ [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
+ [IntrNoMem, IntrWillReturn]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9aa0af3e3a6b17..06c52da5fc07c8 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -778,6 +778,7 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+//
def AnnotateHandle : DXILOp<217, annotateHandle> {
let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 926cbe97f24fda..09e87d5035093b 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,6 +12,7 @@
#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
+#include "llvm-c/Core.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILResource.h"
@@ -395,6 +396,15 @@ static Value *expandStepIntrinsic(CallInst *Orig) {
return Builder.CreateSelect(Cond, Zero, One);
}
+// static Value *expandSplitdoubleIntrinsic(CallInst *Orig) {
+// Value *X = Orig->getOperand(0);
+// Type *Ty = X->getType();
+// IRBuilder<> Builder(Orig);
+
+// Builder.CreateIntrinsic()
+
+// }
+
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
Intrinsic::ID ClampIntrinsic) {
if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -511,6 +521,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
break;
case Intrinsic::dx_step:
Result = expandStepIntrinsic(Orig);
+ break;
+ // case Intrinsic::dx_asuint_splitdouble:
+ // Result = expandSplitdoubleIntrinsic(Orig);
}
if (Result) {
Orig->replaceAllUsesWith(Result);
>From 9a094f4fec017dd2e990b41caf343c1d5081cada Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 23 Sep 2024 21:19:12 +0000
Subject: [PATCH 2/5] adding vector case for splitdouble
---
clang/lib/CodeGen/CGBuiltin.cpp | 62 ++++++++++++++-----
clang/lib/CodeGen/CGExpr.cpp | 8 ++-
clang/lib/CodeGen/CodeGenFunction.h | 4 +-
.../builtins/asuint-splitdouble.hlsl | 4 +-
4 files changed, 57 insertions(+), 21 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f7695b8693f3dc..e0c97270f6d1f1 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,12 +34,14 @@
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -67,6 +69,7 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
+#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -18855,29 +18858,60 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+ auto emitSplitDouble =
+ [](CGBuilderTy *Builder, llvm::Value *arg,
+ llvm::Type *retType) -> std::pair<Value *, Value *> {
+ CallInst *CI = Builder->CreateIntrinsic(
+ retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
+ "hlsl.asuint");
+
+ Value *arg0 = Builder->CreateExtractValue(CI, 0);
+ Value *arg1 = Builder->CreateExtractValue(CI, 1);
+
+ return std::make_pair(arg0, arg1);
+ };
+
CallArgList Args;
- LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
- LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+ auto [Op1BaseLValue, Op1TmpLValue] =
+ EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+ auto [Op2BaseLValue, Op2TmpLValue] =
+ EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
- if (Op0->getType()->isVectorTy()) {
- auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
- llvm::VectorType *i32VecTy = llvm::VectorType::get(
- Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+ if (!Op0->getType()->isVectorTy()) {
+ auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
+
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
- retType = llvm::StructType::get(i32VecTy, i32VecTy);
+ EmitWritebacks(*this, Args);
+ return s;
}
- CallInst *CI =
- Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
- {Op0}, nullptr, "hlsl.asuint");
+ auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+ llvm::VectorType *i32VecTy = llvm::VectorType::get(
+ Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
- Value *arg0 = Builder.CreateExtractValue(CI, 0);
- Value *arg1 = Builder.CreateExtractValue(CI, 1);
+ std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
+
+ for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
+ Value *op = Builder.CreateExtractElement(Op0, idx);
+
+ auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType);
+
+ if (idx == 0) {
+ inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
+ inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
+ } else {
+ inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
+ inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx);
+ }
+ }
- Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+ Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
EmitWritebacks(*this, Args);
return s;
}
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 1c299c4a932ca0..53b60ad477a68b 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -54,6 +54,7 @@
#include <optional>
#include <string>
+#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -5478,8 +5479,9 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
return std::make_pair(BaseLV, TempLV);
}
-LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
- CallArgList &Args, QualType Ty) {
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+ QualType Ty) {
auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
@@ -5494,7 +5496,7 @@ LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
- return TempLV;
+ return std::make_pair(BaseLV, TempLV);
}
LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index ad7c2635500d93..7372faa5656121 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4299,8 +4299,8 @@ class CodeGenFunction : public CodeGenTypeCache {
std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
QualType Ty);
- LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
- QualType Ty);
+ std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+ CallArgList &Args, QualType Ty);
Address EmitExtVectorElementLValue(LValue V);
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index e359354dc3a6df..4326612db96b0f 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -3,8 +3,8 @@
// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}}
// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float fn(double D) {
- uint A, B;
+float2 fn(double2 D) {
+ uint2 A, B;
asuint(D, A, B);
return A + B;
}
>From 81476c7ad27010600dc4b4be1d66e7c7db7c10fb Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 24 Sep 2024 00:50:10 +0000
Subject: [PATCH 3/5] adding lowering to dxil
---
clang/include/clang/Basic/Builtins.td | 4 +-
clang/lib/CodeGen/CGBuiltin.cpp | 15 +++---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 8 +--
clang/lib/Sema/SemaHLSL.cpp | 44 ++++++++--------
.../builtins/asuint-splitdouble.hlsl | 25 +++++++---
.../test/SemaHLSL/BuiltIns/asuint-errors.hlsl | 4 ++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/lib/Target/DirectX/DXIL.td | 11 +++-
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 13 -----
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 13 +++++
llvm/lib/Target/DirectX/DXILOpBuilder.h | 4 ++
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 50 +++++++++++++++++++
12 files changed, 135 insertions(+), 58 deletions(-)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b38957f6e3f15d..4e0566615b5fef 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4788,8 +4788,8 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
-def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
- let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_splitdouble"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e0c97270f6d1f1..e9c44be58289af 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,14 +34,12 @@
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InlineAsm.h"
-#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -69,7 +67,6 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
-#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -18847,7 +18844,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
}
// This should only be called when targeting DXIL
- case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+ case Builtin::BI__builtin_hlsl_splitdouble: {
assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
@@ -18861,9 +18858,9 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto emitSplitDouble =
[](CGBuilderTy *Builder, llvm::Value *arg,
llvm::Type *retType) -> std::pair<Value *, Value *> {
- CallInst *CI = Builder->CreateIntrinsic(
- retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
- "hlsl.asuint");
+ CallInst *CI =
+ Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
+ {arg}, nullptr, "hlsl.asuint");
Value *arg0 = Builder->CreateExtractValue(CI, 0);
Value *arg1 = Builder->CreateExtractValue(CI, 1);
@@ -18877,7 +18874,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto [Op2BaseLValue, Op2TmpLValue] =
EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
- llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+ llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
if (!Op0->getType()->isVectorTy()) {
auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
@@ -18906,7 +18903,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
} else {
inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
- inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx);
+ inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
}
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e8a1e97f344559..dede9583d1bc58 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -432,13 +432,13 @@ template <typename T> constexpr uint asuint(T F) {
/// \param lowbits The output lowbits of D.
/// \param highbits The highbits lowbits D.
#if __is_target_arch(dxil)
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double, out uint, out uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double2, out uint2, out uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double3, out uint3, out uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double4, out uint4, out uint4);
#endif
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 20878442c92338..c0bbe00fe47008 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1467,7 +1467,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}
-bool CheckArgTypeIsCorrect(
+bool CheckArgTypeIsIncorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
@@ -1487,7 +1487,7 @@ bool CheckArgsTypesAreCorrect(
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
Expr *Arg = TheCall->getArg(i);
- if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+ if (CheckArgTypeIsIncorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
@@ -1771,34 +1771,34 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
- case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+ case Builtin::BI__builtin_hlsl_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
- // Expr *Op0 = TheCall->getArg(0);
+ Expr *Op0 = TheCall->getArg(0);
- // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
- // return !PassedType->isDoubleType();
- // };
+ auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+ return !PassedType->hasFloatingRepresentation();
+ };
- // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
- // CheckIsNotDouble)) {
- // return true;
- // }
+ if (CheckArgTypeIsIncorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+ CheckIsNotDouble)) {
+ return true;
+ }
- // Expr *Op1 = TheCall->getArg(1);
- // Expr *Op2 = TheCall->getArg(2);
+ Expr *Op1 = TheCall->getArg(1);
+ Expr *Op2 = TheCall->getArg(2);
- // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
- // return !PassedType->isUnsignedIntegerType();
- // };
+ auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+ return !PassedType->hasUnsignedIntegerRepresentation();
+ };
- // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
- // CheckIsNotUint) ||
- // CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
- // CheckIsNotUint)) {
- // return true;
- // }
+ if (CheckArgTypeIsIncorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+ CheckIsNotUint) ||
+ CheckArgTypeIsIncorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+ CheckIsNotUint)) {
+ return true;
+ }
break;
}
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index 4326612db96b0f..1711c344792aee 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -1,10 +1,23 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
-// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}}
-// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
-// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float2 fn(double2 D) {
- uint2 A, B;
+
+// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+float test_scalar(double D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+float3 test_vector(double3 D) {
+ uint3 A, B;
asuint(D, A, B);
return A + B;
}
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
index 8c56fdddb1c24c..b9a920f9f1b4d0 100644
--- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -6,6 +6,10 @@ uint4 test_asuint_too_many_arg(float p0, float p1) {
// expected-error@-1 {{no matching function for call to 'asuint'}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'V', but 2 arguments were provided}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'F', but 2 arguments were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
}
uint test_asuint_double(double p1) {
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index d8092397881550..04dd26ea54ca80 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -89,7 +89,7 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+def int_dx_splitdouble : DefaultAttrsIntrinsic<
[llvm_anyint_ty, LLVMMatchType<0>],
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
[IntrNoMem, IntrWillReturn]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 06c52da5fc07c8..912d385fe285a2 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,6 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
+def ResSplitDoubleTy : DXILOpParamType;
class DXILOpClass;
@@ -778,7 +779,15 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-//
+
+def SplitDouble : DXILOp<102, splitDouble> {
+ let Doc = "Splits a double into 2 uints";
+ let arguments = [OverloadTy];
+ let result = ResSplitDoubleTy;
+ let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
def AnnotateHandle : DXILOp<217, annotateHandle> {
let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 09e87d5035093b..926cbe97f24fda 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,7 +12,6 @@
#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
-#include "llvm-c/Core.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILResource.h"
@@ -396,15 +395,6 @@ static Value *expandStepIntrinsic(CallInst *Orig) {
return Builder.CreateSelect(Cond, Zero, One);
}
-// static Value *expandSplitdoubleIntrinsic(CallInst *Orig) {
-// Value *X = Orig->getOperand(0);
-// Type *Ty = X->getType();
-// IRBuilder<> Builder(Orig);
-
-// Builder.CreateIntrinsic()
-
-// }
-
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
Intrinsic::ID ClampIntrinsic) {
if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -521,9 +511,6 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
break;
case Intrinsic::dx_step:
Result = expandStepIntrinsic(Orig);
- break;
- // case Intrinsic::dx_asuint_splitdouble:
- // Result = expandSplitdoubleIntrinsic(Orig);
}
if (Result) {
Orig->replaceAllUsesWith(Result);
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 7719d6b1079110..982d7849d9bb8b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,6 +229,13 @@ static StructType *getResPropsType(LLVMContext &Context) {
return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
}
+static StructType *getResSplitDoubleType(LLVMContext &Context) {
+ if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
+ return ST;
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
+}
+
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
Type *OverloadTy) {
switch (Kind) {
@@ -266,6 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getResBindType(Ctx);
case OpParamType::ResPropsTy:
return getResPropsType(Ctx);
+ case OpParamType::ResSplitDoubleTy:
+ return getResSplitDoubleType(Ctx);
}
llvm_unreachable("Invalid parameter kind");
return nullptr;
@@ -467,6 +476,10 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
return ::getResRetType(ElementTy);
}
+StructType *DXILOpBuilder::getResSplitDoubleType(LLVMContext &Context) {
+ return ::getResSplitDoubleType(Context);
+}
+
StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 037ae3822cfb90..8b1e87c283146c 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -49,6 +49,10 @@ class DXILOpBuilder {
/// Get a `%dx.types.ResRet` type with the given element type.
StructType *getResRetType(Type *ElementTy);
+
+ /// Get the `%dx.types.splitdouble` type.
+ StructType *getResSplitDoubleType(LLVMContext &Context);
+
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3ee3ee05563c24..83c6b7f6d503dc 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
@@ -264,6 +265,31 @@ class OpLowerer {
return lowerToBindAndAnnotateHandle(F);
}
+ Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+
+ for (Use &U : Intrin->uses()) {
+ if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
+
+ assert(EVI->getNumIndices() == 1 &&
+ "splitdouble result should be indexed individually.");
+ if (EVI->getNumIndices() != 1)
+ return make_error<StringError>(
+ "splitdouble result should be indexed individually.",
+ inconvertibleErrorCode());
+
+ unsigned int IndexVal = EVI->getIndices()[0];
+
+ auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+ EVI->replaceAllUsesWith(OpEVI);
+ EVI->eraseFromParent();
+ }
+ }
+ Intrin->eraseFromParent();
+
+ return Error::success();
+ }
+
/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
@@ -461,6 +487,27 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool lowerSplitDouble(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+
+ Value *Arg0 = CI->getArgOperand(0);
+
+ Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
+
+ std::array<Value *, 1> Args{Arg0};
+ Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+ OpCode::SplitDouble, Args, CI->getName(), NewRetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+ if (Error E = replaceSplitDoubleCallUsages(CI, *OpCall))
+ return E;
+
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -489,6 +536,9 @@ class OpLowerer {
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
break;
+ case Intrinsic::dx_splitdouble:
+ HasErrors |= lowerSplitDouble(F);
+ break;
}
Updated = true;
}
>From 1afcec6debfe7a30424d2fdf54f4a6a5d3558875 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 00:20:54 +0000
Subject: [PATCH 4/5] adding tests
---
clang/lib/CodeGen/CGExpr.cpp | 3 -
...uint-splitdouble.hlsl => splitdouble.hlsl} | 0
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 2 +-
llvm/test/CodeGen/DirectX/splitdouble.ll | 63 +++++++++++++++++++
4 files changed, 64 insertions(+), 4 deletions(-)
rename clang/test/CodeGenHLSL/builtins/{asuint-splitdouble.hlsl => splitdouble.hlsl} (100%)
create mode 100644 llvm/test/CodeGen/DirectX/splitdouble.ll
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 53b60ad477a68b..459511b54666bf 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,7 +19,6 @@
#include "CGObjCRuntime.h"
#include "CGOpenMPRuntime.h"
#include "CGRecordLayout.h"
-#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
@@ -29,7 +28,6 @@
#include "clang/AST/DeclObjC.h"
#include "clang/AST/NSAPI.h"
#include "clang/AST/StmtVisitor.h"
-#include "clang/AST/Type.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/CodeGenOptions.h"
#include "clang/Basic/SourceManager.h"
@@ -54,7 +52,6 @@
#include <optional>
#include <string>
-#include <utility>
using namespace clang;
using namespace CodeGen;
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
similarity index 100%
rename from clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
rename to clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 83c6b7f6d503dc..81463b19b22062 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -268,7 +268,7 @@ class OpLowerer {
Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
IRBuilder<> &IRB = OpBuilder.getIRB();
- for (Use &U : Intrin->uses()) {
+ for (Use &U : make_early_inc_range(Intrin->uses())) {
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
assert(EVI->getNumIndices() == 1 &&
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
new file mode 100644
index 00000000000000..3ada8c07325431
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -0,0 +1,63 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; ModuleID = '../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl'
+source_filename = "../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl"
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxilv1.3-pc-shadermodel6.3-library"
+
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define noundef float @"?test_scalar@@YAMN@Z"(double noundef %D) local_unnamed_addr #0 {
+entry:
+ ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+ %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
+ %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %add = add i32 %0, %1
+ %conv = uitofp i32 %add to float
+ ret float %conv
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
+
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define noundef <3 x float> @"?test_vector@@YAT?$__vector@M$02@__clang@@T?$__vector@N$02@2@@Z"(<3 x double> noundef %D) local_unnamed_addr #0 {
+entry:
+ %0 = extractelement <3 x double> %D, i64 0
+ ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
+ %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
+ %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %3 = insertelement <3 x i32> poison, i32 %1, i64 0
+ %4 = insertelement <3 x i32> poison, i32 %2, i64 0
+ %5 = extractelement <3 x double> %D, i64 1
+ %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
+ %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
+ %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
+ %8 = insertelement <3 x i32> %3, i32 %6, i64 1
+ %9 = insertelement <3 x i32> %4, i32 %7, i64 1
+ %10 = extractelement <3 x double> %D, i64 2
+ %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
+ %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
+ %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
+ %13 = insertelement <3 x i32> %8, i32 %11, i64 2
+ %14 = insertelement <3 x i32> %9, i32 %12, i64 2
+ %add = add <3 x i32> %13, %14
+ %conv = uitofp <3 x i32> %add to <3 x float>
+ ret <3 x float> %conv
+}
+
+attributes #0 = { alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none) "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0}
+!dx.valver = !{!1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 8}
+!2 = !{!"clang version 20.0.0git (https://github.com/joaosaffran/llvm-project.git 81476c7ad27010600dc4b4be1d66e7c7db7c10fb)"}
>From 15958f4bd4623f380615a320ec09bb607eed52d3 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 05:32:19 +0000
Subject: [PATCH 5/5] adding SPIRV
---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 29 +++++++++++++++++++
.../CodeGenHLSL/builtins/splitdouble.hlsl | 9 ++++++
2 files changed, 38 insertions(+)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index dede9583d1bc58..78f1e98e963931 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -440,6 +440,35 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double3, out uint3, out uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double4, out uint4, out uint4);
+
+#elif __is_target_arch(spirv)
+
+void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
+ uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+ uint4 top = __detail::bit_cast<uint4>(D.zw);
+ lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
+ highbits = uint4(bottom.y, bottom.w, top.y, top.w);
+}
+
+void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
+ uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+ uint2 top = __detail::bit_cast<uint2>(D.z);
+ lowbits = uint3(bottom.x, bottom.z, top.x);
+ highbits = uint3(bottom.y, bottom.w, top.y);
+}
+
+void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
+ uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+ lowbits = uint2(bottom.x, bottom.z);
+ highbits = uint2(bottom.y, bottom.w);
+}
+
+void asuint(double D, out uint lowbits, out uint highbits) {
+ uint2 bottom = __detail::bit_cast<uint2>(D);
+ lowbits = uint(bottom.x);
+ highbits = uint(bottom.y);
+}
+
#endif
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 1711c344792aee..8febc500d3c2b9 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,10 +1,15 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble
+// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
+// SPIRV: call spir_func void {{.*}}asuint{{.*}}(double {{.*}} [[REG]], {{.*}})
float test_scalar(double D) {
uint A, B;
asuint(D, A, B);
@@ -16,6 +21,10 @@ float test_scalar(double D) {
// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble
+// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
+// SPIRV: call spir_func void {{.*}}asuint{{.*}}(<3 x double> {{.*}} [[REG]], {{.*}})
float3 test_vector(double3 D) {
uint3 A, B;
asuint(D, A, B);
More information about the cfe-commits
mailing list