[clang] [llvm] Adding splitdouble HLSL function (PR #109331)
via cfe-commits
cfe-commits at lists.llvm.org
Mon Oct 21 15:46:03 PDT 2024
https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/109331
>From 060a0ab8d88f39bbabfda1bcd9e69f1e28adde87 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 19 Sep 2024 00:13:51 +0000
Subject: [PATCH 01/16] Codegen builtin
---
clang/include/clang/Basic/Builtins.td | 6 ++
clang/lib/CodeGen/CGBuiltin.cpp | 38 ++++++++++++
clang/lib/CodeGen/CGCall.cpp | 5 ++
clang/lib/CodeGen/CGExpr.cpp | 15 ++++-
clang/lib/CodeGen/CodeGenFunction.h | 10 +++-
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 20 +++++++
clang/lib/Sema/SemaHLSL.cpp | 58 ++++++++++++++++---
.../builtins/asuint-splitdouble.hlsl | 10 ++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 5 ++
llvm/lib/Target/DirectX/DXIL.td | 1 +
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 1 +
11 files changed, 155 insertions(+), 14 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..1bff7e6838836c 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,6 +4871,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
+def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..000d2f73151747 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18952,6 +18952,44 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
nullptr, "hlsl.radians");
}
+ // This should only be called when targeting DXIL
+ case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+
+ assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+ E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+ E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+ "asuint operands types mismatch");
+
+ Value *Op0 = EmitScalarExpr(E->getArg(0));
+ const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+ const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+ CallArgList Args;
+ LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+ LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+ llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+ if (Op0->getType()->isVectorTy()) {
+ auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+ llvm::VectorType *i32VecTy = llvm::VectorType::get(
+ Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+
+ retType = llvm::StructType::get(i32VecTy, i32VecTy);
+ }
+
+ CallInst *CI =
+ Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
+ {Op0}, nullptr, "hlsl.asuint");
+
+ Value *arg0 = Builder.CreateExtractValue(CI, 0);
+ Value *arg1 = Builder.CreateExtractValue(CI, 1);
+
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+ EmitWritebacks(*this, Args);
+ return s;
+ }
}
return nullptr;
}
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4ae981e4013e9c..096bbafa4cc694 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4681,6 +4681,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
IsUsed = true;
}
+void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF,
+ const CallArgList &args) {
+ emitWritebacks(CGF, args);
+}
+
void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
QualType type) {
DisableDebugLocationUpdates Dis(*this, E);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 1e8ffb53b53a09..13ca5101f3f903 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,6 +19,7 @@
#include "CGObjCRuntime.h"
#include "CGOpenMPRuntime.h"
#include "CGRecordLayout.h"
+#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
@@ -28,6 +29,7 @@
#include "clang/AST/DeclObjC.h"
#include "clang/AST/NSAPI.h"
#include "clang/AST/StmtVisitor.h"
+#include "clang/AST/Type.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/CodeGenOptions.h"
#include "clang/Basic/SourceManager.h"
@@ -5460,9 +5462,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
return getOrCreateOpaqueLValueMapping(e);
}
-void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
- CallArgList &Args, QualType Ty) {
-
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
// Emitting the casted temporary through an opaque value.
LValue BaseLV = EmitLValue(E->getArgLValue());
OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
@@ -5476,6 +5477,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
TempLV);
OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
+ return std::make_pair(BaseLV, TempLV);
+}
+
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+ CallArgList &Args, QualType Ty) {
+
+ auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
llvm::Value *Addr = TempLV.getAddress().getBasePointer();
llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
@@ -5488,6 +5496,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
+ return TempLV;
}
LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 5f203fe0b128b5..dd66144e5d5f57 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4298,8 +4298,11 @@ class CodeGenFunction : public CodeGenTypeCache {
LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
LValue EmitHLSLArrayAssignLValue(const BinaryOperator *E);
- void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
- QualType Ty);
+
+ std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
+ QualType Ty);
+ LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+ QualType Ty);
Address EmitExtVectorElementLValue(LValue V);
@@ -5149,6 +5152,9 @@ class CodeGenFunction : public CodeGenTypeCache {
SourceLocation ArgLoc, AbstractCallee AC,
unsigned ParmNum);
+ /// EmitWritebacks - Emit writebacks for the call arguments in \p args.
+ void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args);
+
/// EmitCallArg - Emit a single call argument.
void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..406d5e888fdd64 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -438,6 +438,26 @@ template <typename T> constexpr uint asuint(T F) {
return __detail::bit_cast<uint, T>(F);
}
+//===----------------------------------------------------------------------===//
+// asuint splitdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn void asuint(double D, out uint lowbits, out int highbits)
+/// \brief Split and interprets the lowbits and highbits of double D into uints.
+/// \param D The input double.
+/// \param lowbits The output lowbits of D.
+/// \param highbits The highbits lowbits D.
+#if __is_target_arch(dxil)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double, out uint, out uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double2, out uint2, out uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double3, out uint3, out uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double4, out uint4, out uint4);
+#endif
+
//===----------------------------------------------------------------------===//
// atan builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index c6627b0e993226..e170fb1658b61b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,18 +1698,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}
-static bool CheckArgsTypesAreCorrect(
+bool CheckArgTypeIsCorrect(
+ Sema *S, Expr *Arg, QualType ExpectedType,
+ llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+ QualType PassedType = Arg->getType();
+ if (Check(PassedType)) {
+ if (auto *VecTyA = PassedType->getAs<VectorType>())
+ ExpectedType = S->Context.getVectorType(
+ ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
+ S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+ << PassedType << ExpectedType << 1 << 0 << 0;
+ return true;
+ }
+ return false;
+}
+
+bool CheckArgsTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
- QualType PassedType = TheCall->getArg(i)->getType();
- if (Check(PassedType)) {
- if (auto *VecTyA = PassedType->getAs<VectorType>())
- ExpectedType = S->Context.getVectorType(
- ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
- S->Diag(TheCall->getArg(0)->getBeginLoc(),
- diag::err_typecheck_convert_incompatible)
- << PassedType << ExpectedType << 1 << 0 << 0;
+ Expr *Arg = TheCall->getArg(i);
+ if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
@@ -2074,6 +2083,37 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+ if (SemaRef.checkArgCount(TheCall, 3))
+ return true;
+
+ // Expr *Op0 = TheCall->getArg(0);
+
+ // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+ // return !PassedType->isDoubleType();
+ // };
+
+ // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+ // CheckIsNotDouble)) {
+ // return true;
+ // }
+
+ // Expr *Op1 = TheCall->getArg(1);
+ // Expr *Op2 = TheCall->getArg(2);
+
+ // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+ // return !PassedType->isUnsignedIntegerType();
+ // };
+
+ // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+ // CheckIsNotUint) ||
+ // CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+ // CheckIsNotUint)) {
+ // return true;
+ // }
+
+ break;
+ }
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
new file mode 100644
index 00000000000000..e359354dc3a6df
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+
+// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}}
+// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
+// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
+float fn(double D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 27a437a83be6dd..6d239325d4b360 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -90,4 +90,9 @@ def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+
+def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+ [llvm_anyint_ty, LLVMMatchType<0>],
+ [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
+ [IntrNoMem, IntrWillReturn]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..019f0b36cf4ff0 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -778,6 +778,7 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+//
def AnnotateHandle : DXILOp<217, annotateHandle> {
let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index fb5383b3514a5a..fc2755dfb24252 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,6 +12,7 @@
#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
+#include "llvm-c/Core.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILResource.h"
>From 297f76956c38d34acc91b6e090594a3515512eed Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 23 Sep 2024 21:19:12 +0000
Subject: [PATCH 02/16] adding vector case for splitdouble
---
clang/lib/CodeGen/CGBuiltin.cpp | 62 ++++++++++++++-----
clang/lib/CodeGen/CGExpr.cpp | 8 ++-
clang/lib/CodeGen/CodeGenFunction.h | 4 +-
.../builtins/asuint-splitdouble.hlsl | 4 +-
4 files changed, 57 insertions(+), 21 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 000d2f73151747..f974385191bfaa 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,12 +34,14 @@
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -67,6 +69,7 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
+#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -18964,29 +18967,60 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+ auto emitSplitDouble =
+ [](CGBuilderTy *Builder, llvm::Value *arg,
+ llvm::Type *retType) -> std::pair<Value *, Value *> {
+ CallInst *CI = Builder->CreateIntrinsic(
+ retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
+ "hlsl.asuint");
+
+ Value *arg0 = Builder->CreateExtractValue(CI, 0);
+ Value *arg1 = Builder->CreateExtractValue(CI, 1);
+
+ return std::make_pair(arg0, arg1);
+ };
+
CallArgList Args;
- LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
- LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+ auto [Op1BaseLValue, Op1TmpLValue] =
+ EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+ auto [Op2BaseLValue, Op2TmpLValue] =
+ EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
- if (Op0->getType()->isVectorTy()) {
- auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
- llvm::VectorType *i32VecTy = llvm::VectorType::get(
- Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+ if (!Op0->getType()->isVectorTy()) {
+ auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
+
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
- retType = llvm::StructType::get(i32VecTy, i32VecTy);
+ EmitWritebacks(*this, Args);
+ return s;
}
- CallInst *CI =
- Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
- {Op0}, nullptr, "hlsl.asuint");
+ auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+ llvm::VectorType *i32VecTy = llvm::VectorType::get(
+ Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
- Value *arg0 = Builder.CreateExtractValue(CI, 0);
- Value *arg1 = Builder.CreateExtractValue(CI, 1);
+ std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
+
+ for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
+ Value *op = Builder.CreateExtractElement(Op0, idx);
+
+ auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType);
+
+ if (idx == 0) {
+ inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
+ inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
+ } else {
+ inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
+ inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx);
+ }
+ }
- Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+ Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
EmitWritebacks(*this, Args);
return s;
}
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 13ca5101f3f903..14f1c5ef5b6df3 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -54,6 +54,7 @@
#include <optional>
#include <string>
+#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -5480,8 +5481,9 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
return std::make_pair(BaseLV, TempLV);
}
-LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
- CallArgList &Args, QualType Ty) {
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+ QualType Ty) {
auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
@@ -5496,7 +5498,7 @@ LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
- return TempLV;
+ return std::make_pair(BaseLV, TempLV);
}
LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index dd66144e5d5f57..2b48d55a98c9a0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4301,8 +4301,8 @@ class CodeGenFunction : public CodeGenTypeCache {
std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
QualType Ty);
- LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
- QualType Ty);
+ std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+ CallArgList &Args, QualType Ty);
Address EmitExtVectorElementLValue(LValue V);
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index e359354dc3a6df..4326612db96b0f 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -3,8 +3,8 @@
// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}}
// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float fn(double D) {
- uint A, B;
+float2 fn(double2 D) {
+ uint2 A, B;
asuint(D, A, B);
return A + B;
}
>From 17838a79b1592cde7cf504cfd573adbd320c771f Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 24 Sep 2024 00:50:10 +0000
Subject: [PATCH 03/16] adding lowering to dxil
---
clang/include/clang/Basic/Builtins.td | 16 +++++-
clang/lib/CodeGen/CGBuiltin.cpp | 15 +++---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 8 +--
clang/lib/Sema/SemaHLSL.cpp | 44 ++++++++--------
.../builtins/asuint-splitdouble.hlsl | 25 +++++++---
.../test/SemaHLSL/BuiltIns/asuint-errors.hlsl | 4 ++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/lib/Target/DirectX/DXIL.td | 11 +++-
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 1 -
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 13 +++++
llvm/lib/Target/DirectX/DXILOpBuilder.h | 4 ++
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 50 +++++++++++++++++++
12 files changed, 147 insertions(+), 46 deletions(-)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 1bff7e6838836c..b35a24eda6da44 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,8 +4871,20 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
-def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
- let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_radians"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
+def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_elementwise_radians"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void(...)";
+}
+
+def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_splitdouble"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f974385191bfaa..168ee508446d3b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,14 +34,12 @@
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InlineAsm.h"
-#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -69,7 +67,6 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
-#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -18956,7 +18953,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
nullptr, "hlsl.radians");
}
// This should only be called when targeting DXIL
- case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+ case Builtin::BI__builtin_hlsl_splitdouble: {
assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
@@ -18970,9 +18967,9 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto emitSplitDouble =
[](CGBuilderTy *Builder, llvm::Value *arg,
llvm::Type *retType) -> std::pair<Value *, Value *> {
- CallInst *CI = Builder->CreateIntrinsic(
- retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
- "hlsl.asuint");
+ CallInst *CI =
+ Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
+ {arg}, nullptr, "hlsl.asuint");
Value *arg0 = Builder->CreateExtractValue(CI, 0);
Value *arg1 = Builder->CreateExtractValue(CI, 1);
@@ -18986,7 +18983,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto [Op2BaseLValue, Op2TmpLValue] =
EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
- llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+ llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
if (!Op0->getType()->isVectorTy()) {
auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
@@ -19015,7 +19012,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
} else {
inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
- inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx);
+ inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
}
}
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 406d5e888fdd64..ef59ac777796f5 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -448,13 +448,13 @@ template <typename T> constexpr uint asuint(T F) {
/// \param lowbits The output lowbits of D.
/// \param highbits The highbits lowbits D.
#if __is_target_arch(dxil)
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double, out uint, out uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double2, out uint2, out uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double3, out uint3, out uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double4, out uint4, out uint4);
#endif
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e170fb1658b61b..4f7209ca7bd0a7 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,7 +1698,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}
-bool CheckArgTypeIsCorrect(
+bool CheckArgTypeIsIncorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
@@ -1718,7 +1718,7 @@ bool CheckArgsTypesAreCorrect(
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
Expr *Arg = TheCall->getArg(i);
- if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
+ if (CheckArgTypeIsIncorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
@@ -2083,34 +2083,34 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
- case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+ case Builtin::BI__builtin_hlsl_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
- // Expr *Op0 = TheCall->getArg(0);
+ Expr *Op0 = TheCall->getArg(0);
- // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
- // return !PassedType->isDoubleType();
- // };
+ auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+ return !PassedType->hasFloatingRepresentation();
+ };
- // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
- // CheckIsNotDouble)) {
- // return true;
- // }
+ if (CheckArgTypeIsIncorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+ CheckIsNotDouble)) {
+ return true;
+ }
- // Expr *Op1 = TheCall->getArg(1);
- // Expr *Op2 = TheCall->getArg(2);
+ Expr *Op1 = TheCall->getArg(1);
+ Expr *Op2 = TheCall->getArg(2);
- // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
- // return !PassedType->isUnsignedIntegerType();
- // };
+ auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+ return !PassedType->hasUnsignedIntegerRepresentation();
+ };
- // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
- // CheckIsNotUint) ||
- // CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
- // CheckIsNotUint)) {
- // return true;
- // }
+ if (CheckArgTypeIsIncorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+ CheckIsNotUint) ||
+ CheckArgTypeIsIncorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+ CheckIsNotUint)) {
+ return true;
+ }
break;
}
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index 4326612db96b0f..1711c344792aee 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -1,10 +1,23 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
-// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}}
-// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
-// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float2 fn(double2 D) {
- uint2 A, B;
+
+// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+float test_scalar(double D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+float3 test_vector(double3 D) {
+ uint3 A, B;
asuint(D, A, B);
return A + B;
}
diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
index 8c56fdddb1c24c..b9a920f9f1b4d0 100644
--- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl
@@ -6,6 +6,10 @@ uint4 test_asuint_too_many_arg(float p0, float p1) {
// expected-error@-1 {{no matching function for call to 'asuint'}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'V', but 2 arguments were provided}}
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'F', but 2 arguments were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
+ // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}}
}
uint test_asuint_double(double p1) {
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 6d239325d4b360..34ac423e278bba 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -91,7 +91,7 @@ def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+def int_dx_splitdouble : DefaultAttrsIntrinsic<
[llvm_anyint_ty, LLVMMatchType<0>],
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
[IntrNoMem, IntrWillReturn]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 019f0b36cf4ff0..338cc546348b8d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,6 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
+def ResSplitDoubleTy : DXILOpParamType;
class DXILOpClass;
@@ -778,7 +779,15 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-//
+
+def SplitDouble : DXILOp<102, splitDouble> {
+ let Doc = "Splits a double into 2 uints";
+ let arguments = [OverloadTy];
+ let result = ResSplitDoubleTy;
+ let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
def AnnotateHandle : DXILOp<217, annotateHandle> {
let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index fc2755dfb24252..fb5383b3514a5a 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,7 +12,6 @@
#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
-#include "llvm-c/Core.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILResource.h"
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 7719d6b1079110..982d7849d9bb8b 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,6 +229,13 @@ static StructType *getResPropsType(LLVMContext &Context) {
return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
}
+static StructType *getResSplitDoubleType(LLVMContext &Context) {
+ if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
+ return ST;
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
+}
+
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
Type *OverloadTy) {
switch (Kind) {
@@ -266,6 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getResBindType(Ctx);
case OpParamType::ResPropsTy:
return getResPropsType(Ctx);
+ case OpParamType::ResSplitDoubleTy:
+ return getResSplitDoubleType(Ctx);
}
llvm_unreachable("Invalid parameter kind");
return nullptr;
@@ -467,6 +476,10 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
return ::getResRetType(ElementTy);
}
+StructType *DXILOpBuilder::getResSplitDoubleType(LLVMContext &Context) {
+ return ::getResSplitDoubleType(Context);
+}
+
StructType *DXILOpBuilder::getHandleType() {
return ::getHandleType(IRB.getContext());
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 037ae3822cfb90..8b1e87c283146c 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -49,6 +49,10 @@ class DXILOpBuilder {
/// Get a `%dx.types.ResRet` type with the given element type.
StructType *getResRetType(Type *ElementTy);
+
+ /// Get the `%dx.types.splitdouble` type.
+ StructType *getResSplitDoubleType(LLVMContext &Context);
+
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index c62ba8c21d6791..28e1cb2ce2e19f 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -17,6 +17,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
@@ -263,6 +264,31 @@ class OpLowerer {
return lowerToBindAndAnnotateHandle(F);
}
+ Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+
+ for (Use &U : Intrin->uses()) {
+ if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
+
+ assert(EVI->getNumIndices() == 1 &&
+ "splitdouble result should be indexed individually.");
+ if (EVI->getNumIndices() != 1)
+ return make_error<StringError>(
+ "splitdouble result should be indexed individually.",
+ inconvertibleErrorCode());
+
+ unsigned int IndexVal = EVI->getIndices()[0];
+
+ auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+ EVI->replaceAllUsesWith(OpEVI);
+ EVI->eraseFromParent();
+ }
+ }
+ Intrin->eraseFromParent();
+
+ return Error::success();
+ }
+
/// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
/// Since we expect to be post-scalarization, make an effort to avoid vectors.
Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
@@ -460,6 +486,27 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool lowerSplitDouble(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+
+ Value *Arg0 = CI->getArgOperand(0);
+
+ Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
+
+ std::array<Value *, 1> Args{Arg0};
+ Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+ OpCode::SplitDouble, Args, CI->getName(), NewRetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+ if (Error E = replaceSplitDoubleCallUsages(CI, *OpCall))
+ return E;
+
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -488,6 +535,9 @@ class OpLowerer {
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
break;
+ case Intrinsic::dx_splitdouble:
+ HasErrors |= lowerSplitDouble(F);
+ break;
}
Updated = true;
}
>From c92fe8b9d2d0b242d850f812b13b58c79df25062 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 00:20:54 +0000
Subject: [PATCH 04/16] adding tests
---
clang/lib/CodeGen/CGExpr.cpp | 3 -
...uint-splitdouble.hlsl => splitdouble.hlsl} | 0
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 2 +-
llvm/test/CodeGen/DirectX/splitdouble.ll | 63 +++++++++++++++++++
4 files changed, 64 insertions(+), 4 deletions(-)
rename clang/test/CodeGenHLSL/builtins/{asuint-splitdouble.hlsl => splitdouble.hlsl} (100%)
create mode 100644 llvm/test/CodeGen/DirectX/splitdouble.ll
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 14f1c5ef5b6df3..8757b08079dab6 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,7 +19,6 @@
#include "CGObjCRuntime.h"
#include "CGOpenMPRuntime.h"
#include "CGRecordLayout.h"
-#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
@@ -29,7 +28,6 @@
#include "clang/AST/DeclObjC.h"
#include "clang/AST/NSAPI.h"
#include "clang/AST/StmtVisitor.h"
-#include "clang/AST/Type.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/CodeGenOptions.h"
#include "clang/Basic/SourceManager.h"
@@ -54,7 +52,6 @@
#include <optional>
#include <string>
-#include <utility>
using namespace clang;
using namespace CodeGen;
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
similarity index 100%
rename from clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
rename to clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 28e1cb2ce2e19f..2b62575ea77821 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -267,7 +267,7 @@ class OpLowerer {
Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
IRBuilder<> &IRB = OpBuilder.getIRB();
- for (Use &U : Intrin->uses()) {
+ for (Use &U : make_early_inc_range(Intrin->uses())) {
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
assert(EVI->getNumIndices() == 1 &&
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
new file mode 100644
index 00000000000000..3ada8c07325431
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -0,0 +1,63 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+; ModuleID = '../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl'
+source_filename = "../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl"
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxilv1.3-pc-shadermodel6.3-library"
+
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define noundef float @"?test_scalar@@YAMN@Z"(double noundef %D) local_unnamed_addr #0 {
+entry:
+ ; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+ %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
+ %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %add = add i32 %0, %1
+ %conv = uitofp i32 %add to float
+ ret float %conv
+}
+
+; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
+declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
+
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define noundef <3 x float> @"?test_vector@@YAT?$__vector@M$02@__clang@@T?$__vector@N$02@2@@Z"(<3 x double> noundef %D) local_unnamed_addr #0 {
+entry:
+ %0 = extractelement <3 x double> %D, i64 0
+ ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
+ %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
+ %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
+ %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %3 = insertelement <3 x i32> poison, i32 %1, i64 0
+ %4 = insertelement <3 x i32> poison, i32 %2, i64 0
+ %5 = extractelement <3 x double> %D, i64 1
+ %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
+ %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
+ %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
+ %8 = insertelement <3 x i32> %3, i32 %6, i64 1
+ %9 = insertelement <3 x i32> %4, i32 %7, i64 1
+ %10 = extractelement <3 x double> %D, i64 2
+ %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
+ %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
+ %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
+ %13 = insertelement <3 x i32> %8, i32 %11, i64 2
+ %14 = insertelement <3 x i32> %9, i32 %12, i64 2
+ %add = add <3 x i32> %13, %14
+ %conv = uitofp <3 x i32> %add to <3 x float>
+ ret <3 x float> %conv
+}
+
+attributes #0 = { alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none) "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
+
+!llvm.module.flags = !{!0}
+!dx.valver = !{!1}
+!llvm.ident = !{!2}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 8}
+!2 = !{!"clang version 20.0.0git (https://github.com/joaosaffran/llvm-project.git 81476c7ad27010600dc4b4be1d66e7c7db7c10fb)"}
>From a36f37517e2cf67a9b1c1bb03b96b27466dde2a6 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 05:32:19 +0000
Subject: [PATCH 05/16] adding SPIRV
---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 29 +++++++++++++++++++
.../CodeGenHLSL/builtins/splitdouble.hlsl | 9 ++++++
2 files changed, 38 insertions(+)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ef59ac777796f5..0ad944b9e5d181 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -456,6 +456,35 @@ _HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double3, out uint3, out uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double4, out uint4, out uint4);
+
+#elif __is_target_arch(spirv)
+
+void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
+ uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+ uint4 top = __detail::bit_cast<uint4>(D.zw);
+ lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
+ highbits = uint4(bottom.y, bottom.w, top.y, top.w);
+}
+
+void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
+ uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+ uint2 top = __detail::bit_cast<uint2>(D.z);
+ lowbits = uint3(bottom.x, bottom.z, top.x);
+ highbits = uint3(bottom.y, bottom.w, top.y);
+}
+
+void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
+ uint4 bottom = __detail::bit_cast<uint4>(D.xy);
+ lowbits = uint2(bottom.x, bottom.z);
+ highbits = uint2(bottom.y, bottom.w);
+}
+
+void asuint(double D, out uint lowbits, out uint highbits) {
+ uint2 bottom = __detail::bit_cast<uint2>(D);
+ lowbits = uint(bottom.x);
+ highbits = uint(bottom.y);
+}
+
#endif
//===----------------------------------------------------------------------===//
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 1711c344792aee..8febc500d3c2b9 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,10 +1,15 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble
+// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
+// SPIRV: call spir_func void {{.*}}asuint{{.*}}(double {{.*}} [[REG]], {{.*}})
float test_scalar(double D) {
uint A, B;
asuint(D, A, B);
@@ -16,6 +21,10 @@ float test_scalar(double D) {
// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble
+// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
+// SPIRV: call spir_func void {{.*}}asuint{{.*}}(<3 x double> {{.*}} [[REG]], {{.*}})
float3 test_vector(double3 D) {
uint3 A, B;
asuint(D, A, B);
>From 8f4d4016ea96f16e381ef5567cb8ffb252b23076 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 26 Sep 2024 17:31:19 +0000
Subject: [PATCH 06/16] fixing hlsl-lang-targets-spirv.hlsl test
---
clang/test/Driver/hlsl-lang-targets-spirv.hlsl | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
index 61b10e1648c52b..5928c948315f1e 100644
--- a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
+++ b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
@@ -3,12 +3,12 @@
// Supported targets
//
-// RUN: %clang -target dxil-unknown-shadermodel6.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv-unknown-vulkan-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv1.5-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -target spirv1.6-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple dxil-unknown-shadermodel6.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv-unknown-vulkan-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv1.5-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -cc1 -triple spirv1.6-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
// Empty Vulkan environment
//
>From efc52dec7517ac95241188c6209ab3c6dac49f84 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 27 Sep 2024 21:06:27 +0000
Subject: [PATCH 07/16] fixing comments in test
---
llvm/test/CodeGen/DirectX/splitdouble.ll | 27 ++++++------------------
1 file changed, 7 insertions(+), 20 deletions(-)
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index 3ada8c07325431..bfd337042851bd 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,12 +1,11 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-; ModuleID = '../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl'
-source_filename = "../clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl"
-target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
-target triple = "dxilv1.3-pc-shadermodel6.3-library"
+; Make sure DXILOpLowering is correctly generating the dxil op code call.
-; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
-define noundef float @"?test_scalar@@YAMN@Z"(double noundef %D) local_unnamed_addr #0 {
+
+
+; CHECK-LABEL: define noundef float @test_scalar_double_split
+define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
entry:
; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
@@ -19,11 +18,10 @@ entry:
ret float %conv
}
-; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none)
declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
-; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
-define noundef <3 x float> @"?test_vector@@YAT?$__vector@M$02@__clang@@T?$__vector@N$02@2@@Z"(<3 x double> noundef %D) local_unnamed_addr #0 {
+; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
+define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
entry:
%0 = extractelement <3 x double> %D, i64 0
; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
@@ -50,14 +48,3 @@ entry:
%conv = uitofp <3 x i32> %add to <3 x float>
ret <3 x float> %conv
}
-
-attributes #0 = { alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none) "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
-attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) }
-
-!llvm.module.flags = !{!0}
-!dx.valver = !{!1}
-!llvm.ident = !{!2}
-
-!0 = !{i32 1, !"wchar_size", i32 4}
-!1 = !{i32 1, i32 8}
-!2 = !{!"clang version 20.0.0git (https://github.com/joaosaffran/llvm-project.git 81476c7ad27010600dc4b4be1d66e7c7db7c10fb)"}
>From f7c6b4b998ba53faf3b686166a78c1389d3daa7d Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 27 Sep 2024 23:25:47 +0000
Subject: [PATCH 08/16] changing intrinsic signature to return vector
---
clang/lib/CodeGen/CGBuiltin.cpp | 8 +++--
.../CodeGenHLSL/builtins/splitdouble.hlsl | 18 +++++-----
llvm/include/llvm/IR/IntrinsicsDirectX.td | 4 +--
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 19 ++++-------
llvm/test/CodeGen/DirectX/splitdouble.ll | 34 +++++++++----------
5 files changed, 41 insertions(+), 42 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 168ee508446d3b..4e701ba8726f9f 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -39,6 +39,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
@@ -18971,8 +18972,8 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
{arg}, nullptr, "hlsl.asuint");
- Value *arg0 = Builder->CreateExtractValue(CI, 0);
- Value *arg1 = Builder->CreateExtractValue(CI, 1);
+ Value *arg0 = Builder->CreateExtractElement(CI, (uint64_t)0);
+ Value *arg1 = Builder->CreateExtractElement(CI, (uint64_t)1);
return std::make_pair(arg0, arg1);
};
@@ -18983,7 +18984,8 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto [Op2BaseLValue, Op2TmpLValue] =
EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
- llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+ llvm::VectorType *retType =
+ llvm::VectorType::get(Int32Ty, ElementCount::getFixed(2));
if (!Op0->getType()->isVectorTy()) {
auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 8febc500d3c2b9..f9e2122f4587a1 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,11 +1,12 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// CHECK: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
+// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[REG]])
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
// SPIRV-NOT: @llvm.dx.splitdouble
// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
@@ -17,10 +18,11 @@ float test_scalar(double D) {
}
// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
-// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// CHECK: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[REG]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[VALREG]])
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
+// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
// SPIRV-NOT: @llvm.dx.splitdouble
// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 34ac423e278bba..5719b59f75bd3b 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,7 +92,7 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_splitdouble : DefaultAttrsIntrinsic<
- [llvm_anyint_ty, LLVMMatchType<0>],
- [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
+ [llvm_anyvector_ty],
+ [llvm_double_ty],
[IntrNoMem, IntrWillReturn]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 2b62575ea77821..0e3c35fe6a6147 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -268,20 +268,15 @@ class OpLowerer {
IRBuilder<> &IRB = OpBuilder.getIRB();
for (Use &U : make_early_inc_range(Intrin->uses())) {
- if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
-
- assert(EVI->getNumIndices() == 1 &&
- "splitdouble result should be indexed individually.");
- if (EVI->getNumIndices() != 1)
- return make_error<StringError>(
- "splitdouble result should be indexed individually.",
- inconvertibleErrorCode());
+ if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
+ if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
- unsigned int IndexVal = EVI->getIndices()[0];
+ size_t IndexVal = IndexOp->getZExtValue();
- auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
- EVI->replaceAllUsesWith(OpEVI);
- EVI->eraseFromParent();
+ auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+ EEI->replaceAllUsesWith(OpEVI);
+ EEI->eraseFromParent();
+ }
}
}
Intrin->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index bfd337042851bd..e1da2b2d4a9d66 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,8 +1,7 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-; Make sure DXILOpLowering is correctly generating the dxil op code call.
-
-
+; Make sure DXILOpLowering is correctly generating the dxil op code call, with and without scalarizer.
; CHECK-LABEL: define noundef float @test_scalar_double_split
define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
@@ -10,15 +9,16 @@ entry:
; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
- %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
- %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
- %add = add i32 %0, %1
+ %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %D)
+ %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
+ %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
+ %add = add i32 %1, %2
%conv = uitofp i32 %add to float
ret float %conv
}
-declare { i32, i32 } @llvm.dx.splitdouble.i32(double) #1
+declare <2 x i32> @llvm.dx.splitdouble.v2i32(double) #1
+
; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
@@ -27,21 +27,21 @@ entry:
; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
- %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
- %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %0)
+ %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
+ %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
%3 = insertelement <3 x i32> poison, i32 %1, i64 0
%4 = insertelement <3 x i32> poison, i32 %2, i64 0
%5 = extractelement <3 x double> %D, i64 1
- %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
- %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
- %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
+ %hlsl.asuint2 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %5)
+ %6 = extractelement <2 x i32> %hlsl.asuint2, i64 0
+ %7 = extractelement <2 x i32> %hlsl.asuint2, i64 1
%8 = insertelement <3 x i32> %3, i32 %6, i64 1
%9 = insertelement <3 x i32> %4, i32 %7, i64 1
%10 = extractelement <3 x double> %D, i64 2
- %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
- %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
- %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
+ %hlsl.asuint3 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %10)
+ %11 = extractelement <2 x i32> %hlsl.asuint3, i64 0
+ %12 = extractelement <2 x i32> %hlsl.asuint3, i64 1
%13 = insertelement <3 x i32> %8, i32 %11, i64 2
%14 = insertelement <3 x i32> %9, i32 %12, i64 2
%add = add <3 x i32> %13, %14
>From bb2c9722f40fef198d81cfcfec5de44d05d87420 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Sat, 28 Sep 2024 06:19:42 +0000
Subject: [PATCH 09/16] pushing original changes
---
clang/lib/CodeGen/CGBuiltin.cpp | 7 +++--
.../CodeGenHLSL/builtins/splitdouble.hlsl | 20 +++++++-------
.../test/Driver/hlsl-lang-targets-spirv.hlsl | 12 ++++-----
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 19 +++++++++-----
llvm/test/CodeGen/DirectX/splitdouble.ll | 26 +++++++++----------
6 files changed, 45 insertions(+), 41 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 4e701ba8726f9f..04a36bd6606152 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18972,8 +18972,8 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
{arg}, nullptr, "hlsl.asuint");
- Value *arg0 = Builder->CreateExtractElement(CI, (uint64_t)0);
- Value *arg1 = Builder->CreateExtractElement(CI, (uint64_t)1);
+ Value *arg0 = Builder->CreateExtractValue(CI, 0);
+ Value *arg1 = Builder->CreateExtractValue(CI, 1);
return std::make_pair(arg0, arg1);
};
@@ -18984,8 +18984,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto [Op2BaseLValue, Op2TmpLValue] =
EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
- llvm::VectorType *retType =
- llvm::VectorType::get(Int32Ty, ElementCount::getFixed(2));
+ llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
if (!Op0->getType()->isVectorTy()) {
auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index f9e2122f4587a1..b937bb5d4d343d 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,12 +1,12 @@
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
+
// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
-// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[REG]])
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
+// CHECK: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
// SPIRV-NOT: @llvm.dx.splitdouble
// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
@@ -17,12 +17,12 @@ float test_scalar(double D) {
return A + B;
}
+
// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align
-// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[REG]], i64 [[VALIDX:[0-3]]]
-// CHECK-NEXT: [[VALRET:%.*]] = call <2 x i32> @llvm.dx.splitdouble.v2i32(double [[VALREG]])
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 0
-// CHECK-NEXT: extractelement <2 x i32> [[VALRET]], i64 1
+// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
+// CHECK-NEXT: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
// SPIRV-NOT: @llvm.dx.splitdouble
// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
diff --git a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
index 5928c948315f1e..61b10e1648c52b 100644
--- a/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
+++ b/clang/test/Driver/hlsl-lang-targets-spirv.hlsl
@@ -3,12 +3,12 @@
// Supported targets
//
-// RUN: %clang -cc1 -triple dxil-unknown-shadermodel6.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv-unknown-vulkan-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv1.5-unknown-vulkan1.2-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
-// RUN: %clang -cc1 -triple spirv1.6-unknown-vulkan1.3-compute %s -S -disable-llvm-passes -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target dxil-unknown-shadermodel6.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv-unknown-vulkan-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv1.5-unknown-vulkan1.2-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
+// RUN: %clang -target spirv1.6-unknown-vulkan1.3-compute %s -S -o /dev/null 2>&1 | FileCheck --allow-empty --check-prefix=CHECK-VALID %s
// Empty Vulkan environment
//
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 5719b59f75bd3b..3a5d6349f85622 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -92,7 +92,7 @@ def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_splitdouble : DefaultAttrsIntrinsic<
- [llvm_anyvector_ty],
+ [llvm_anyint_ty, LLVMMatchType<0>],
[llvm_double_ty],
[IntrNoMem, IntrWillReturn]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 0e3c35fe6a6147..340914ff2cf422 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -23,7 +23,9 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
+#include "llvm/Object/Error.h"
#include "llvm/Pass.h"
+#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "dxil-op-lower"
@@ -268,17 +270,20 @@ class OpLowerer {
IRBuilder<> &IRB = OpBuilder.getIRB();
for (Use &U : make_early_inc_range(Intrin->uses())) {
- if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
- if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
+ if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
- size_t IndexVal = IndexOp->getZExtValue();
+ if (EVI->getNumIndices() != 1)
+ return createStringError(std::errc::invalid_argument,
+ "Splitdouble has only 2 elements");
- auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
- EEI->replaceAllUsesWith(OpEVI);
- EEI->eraseFromParent();
- }
+ size_t IndexVal = EVI->getIndices()[0];
+
+ auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
+ EVI->replaceAllUsesWith(OpEVI);
+ EVI->eraseFromParent();
}
}
+
Intrin->eraseFromParent();
return Error::success();
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index e1da2b2d4a9d66..c62b7dd2371ba2 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -9,10 +9,10 @@ entry:
; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %D)
- %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
- %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
- %add = add i32 %1, %2
+ %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+ %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
+ %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %add = add i32 %0, %1
%conv = uitofp i32 %add to float
ret float %conv
}
@@ -27,21 +27,21 @@ entry:
; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- %hlsl.asuint = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %0)
- %1 = extractelement <2 x i32> %hlsl.asuint, i64 0
- %2 = extractelement <2 x i32> %hlsl.asuint, i64 1
+ %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
+ %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
+ %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
%3 = insertelement <3 x i32> poison, i32 %1, i64 0
%4 = insertelement <3 x i32> poison, i32 %2, i64 0
%5 = extractelement <3 x double> %D, i64 1
- %hlsl.asuint2 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %5)
- %6 = extractelement <2 x i32> %hlsl.asuint2, i64 0
- %7 = extractelement <2 x i32> %hlsl.asuint2, i64 1
+ %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
+ %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
+ %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
%8 = insertelement <3 x i32> %3, i32 %6, i64 1
%9 = insertelement <3 x i32> %4, i32 %7, i64 1
%10 = extractelement <3 x double> %D, i64 2
- %hlsl.asuint3 = call <2 x i32> @llvm.dx.splitdouble.v2i32(double %10)
- %11 = extractelement <2 x i32> %hlsl.asuint3, i64 0
- %12 = extractelement <2 x i32> %hlsl.asuint3, i64 1
+ %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
+ %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
+ %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
%13 = insertelement <3 x i32> %8, i32 %11, i64 2
%14 = insertelement <3 x i32> %9, i32 %12, i64 2
%add = add <3 x i32> %13, %14
>From 645021dc6a9ff551cb22f1ef0fbd8c38b3fce8b5 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Sat, 28 Sep 2024 18:32:45 +0000
Subject: [PATCH 10/16] adding static inline attributes
---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 0ad944b9e5d181..80a63b33ea99e6 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -459,27 +459,27 @@ void asuint(double4, out uint4, out uint4);
#elif __is_target_arch(spirv)
-void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
+static inline void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
uint4 bottom = __detail::bit_cast<uint4>(D.xy);
uint4 top = __detail::bit_cast<uint4>(D.zw);
lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
highbits = uint4(bottom.y, bottom.w, top.y, top.w);
}
-void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
+static inline void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
uint4 bottom = __detail::bit_cast<uint4>(D.xy);
uint2 top = __detail::bit_cast<uint2>(D.z);
lowbits = uint3(bottom.x, bottom.z, top.x);
highbits = uint3(bottom.y, bottom.w, top.y);
}
-void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
+static inline void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
uint4 bottom = __detail::bit_cast<uint4>(D.xy);
lowbits = uint2(bottom.x, bottom.z);
highbits = uint2(bottom.y, bottom.w);
}
-void asuint(double D, out uint lowbits, out uint highbits) {
+static inline void asuint(double D, out uint lowbits, out uint highbits) {
uint2 bottom = __detail::bit_cast<uint2>(D);
lowbits = uint(bottom.x);
highbits = uint(bottom.y);
>From dec5db62aba3bb76d033175bb67b2b918f441058 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 1 Oct 2024 19:40:24 +0000
Subject: [PATCH 11/16] refactoring spirv
---
clang/lib/CodeGen/CGBuiltin.cpp | 8 ++---
clang/lib/CodeGen/CGHLSLRuntime.h | 1 +
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 30 -------------------
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 2 ++
4 files changed, 7 insertions(+), 34 deletions(-)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 04a36bd6606152..74d167c5c168f1 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18966,10 +18966,10 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
auto emitSplitDouble =
- [](CGBuilderTy *Builder, llvm::Value *arg,
+ [](CGBuilderTy *Builder, llvm::Intrinsic::ID intrId, llvm::Value *arg,
llvm::Type *retType) -> std::pair<Value *, Value *> {
CallInst *CI =
- Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble,
+ Builder->CreateIntrinsic(retType, intrId,
{arg}, nullptr, "hlsl.asuint");
Value *arg0 = Builder->CreateExtractValue(CI, 0);
@@ -18987,7 +18987,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
if (!Op0->getType()->isVectorTy()) {
- auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
+ auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), Op0, retType);
Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
@@ -19006,7 +19006,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
Value *op = Builder.CreateExtractElement(Op0, idx);
- auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType);
+ auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), op, retType);
if (idx == 0) {
inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..9fa1a495c5d602 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -88,6 +88,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Splitdouble, splitdouble);
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 80a63b33ea99e6..a669a673f65a69 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -447,7 +447,6 @@ template <typename T> constexpr uint asuint(T F) {
/// \param D The input double.
/// \param lowbits The output lowbits of D.
/// \param highbits The output highbits of D.
-#if __is_target_arch(dxil)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double, out uint, out uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
@@ -457,35 +456,6 @@ void asuint(double3, out uint3, out uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
void asuint(double4, out uint4, out uint4);
-#elif __is_target_arch(spirv)
-
-static inline void asuint(double4 D, out uint4 lowbits, out uint4 highbits) {
- uint4 bottom = __detail::bit_cast<uint4>(D.xy);
- uint4 top = __detail::bit_cast<uint4>(D.zw);
- lowbits = uint4(bottom.x, bottom.z, top.x, top.z);
- highbits = uint4(bottom.y, bottom.w, top.y, top.w);
-}
-
-static inline void asuint(double3 D, out uint3 lowbits, out uint3 highbits) {
- uint4 bottom = __detail::bit_cast<uint4>(D.xy);
- uint2 top = __detail::bit_cast<uint2>(D.z);
- lowbits = uint3(bottom.x, bottom.z, top.x);
- highbits = uint3(bottom.y, bottom.w, top.y);
-}
-
-static inline void asuint(double2 D, out uint2 lowbits, out uint2 highbits) {
- uint4 bottom = __detail::bit_cast<uint4>(D.xy);
- lowbits = uint2(bottom.x, bottom.z);
- highbits = uint2(bottom.y, bottom.w);
-}
-
-static inline void asuint(double D, out uint lowbits, out uint highbits) {
- uint2 bottom = __detail::bit_cast<uint2>(D);
- lowbits = uint(bottom.x);
- highbits = uint(bottom.y);
-}
-
-#endif
//===----------------------------------------------------------------------===//
// atan builtins
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..b586ad1f4d50ce 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2569,6 +2569,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
}
case Intrinsic::spv_wave_readlane:
return selectWaveReadLaneAt(ResVReg, ResType, I);
+ case Intrinsic::spv_splitdouble:
+ return selectSplitdouble(ResVReg, ResType, I);
case Intrinsic::spv_step:
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
case Intrinsic::spv_radians:
>From 69e0e24d013b5c185a2c6b287c09eab04f3c6bf4 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Wed, 2 Oct 2024 18:57:58 +0000
Subject: [PATCH 12/16] adding dxil codegen
---
clang/lib/CodeGen/CGBuiltin.cpp | 60 ++++++-------------
clang/lib/CodeGen/CGHLSLRuntime.h | 1 -
.../CodeGenHLSL/builtins/splitdouble.hlsl | 26 +++-----
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 6 ++
llvm/test/CodeGen/DirectX/splitdouble.ll | 51 +++-------------
.../test/CodeGen/DirectX/splitdouble_error.ll | 16 +++++
7 files changed, 57 insertions(+), 105 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/splitdouble_error.ll
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 74d167c5c168f1..a996cdb8f15815 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18960,67 +18960,41 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
"asuint operands types mismatch");
-
Value *Op0 = EmitScalarExpr(E->getArg(0));
const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
- auto emitSplitDouble =
- [](CGBuilderTy *Builder, llvm::Intrinsic::ID intrId, llvm::Value *arg,
- llvm::Type *retType) -> std::pair<Value *, Value *> {
- CallInst *CI =
- Builder->CreateIntrinsic(retType, intrId,
- {arg}, nullptr, "hlsl.asuint");
-
- Value *arg0 = Builder->CreateExtractValue(CI, 0);
- Value *arg1 = Builder->CreateExtractValue(CI, 1);
-
- return std::make_pair(arg0, arg1);
- };
-
CallArgList Args;
auto [Op1BaseLValue, Op1TmpLValue] =
EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
auto [Op2BaseLValue, Op2TmpLValue] =
EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
- llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+ if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) {
- if (!Op0->getType()->isVectorTy()) {
- auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), Op0, retType);
-
- Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
- EmitWritebacks(*this, Args);
- return s;
- }
+ llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
- auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+ if (Op0->getType()->isVectorTy()) {
+ auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
- llvm::VectorType *i32VecTy = llvm::VectorType::get(
- Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+ llvm::VectorType *i32VecTy = llvm::VectorType::get(
+ Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+ retType = llvm::StructType::get(i32VecTy, i32VecTy);
+ }
- std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
+ CallInst *CI =
+ Builder.CreateIntrinsic(retType, Intrinsic::dx_splitdouble, {Op0},
+ nullptr, "hlsl.splitdouble");
- for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
- Value *op = Builder.CreateExtractElement(Op0, idx);
+ Value *arg0 = Builder.CreateExtractValue(CI, 0);
+ Value *arg1 = Builder.CreateExtractValue(CI, 1);
- auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), op, retType);
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
- if (idx == 0) {
- inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
- inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
- } else {
- inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
- inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
- }
+ EmitWritebacks(*this, Args);
+ return s;
}
-
- Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
- EmitWritebacks(*this, Args);
- return s;
}
}
return nullptr;
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 9fa1a495c5d602..ff7df41b5c62e7 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -88,7 +88,6 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
- GENERATE_HLSL_INTRINSIC_FUNCTION(Splitdouble, splitdouble);
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index b937bb5d4d343d..4f3a2330af924e 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,33 +1,23 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
-// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// CHECK: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
+// CHECK: define {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
-// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble
-// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
-// SPIRV: call spir_func void {{.*}}asuint{{.*}}(double {{.*}} [[REG]], {{.*}})
-float test_scalar(double D) {
+uint test_scalar(double D) {
uint A, B;
asuint(D, A, B);
return A + B;
}
-// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
-// CHECK-NEXT: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
-// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
-// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
-// SPIRV-NOT: @llvm.dx.splitdouble
-// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
-// SPIRV: call spir_func void {{.*}}asuint{{.*}}(<3 x double> {{.*}} [[REG]], {{.*}})
-float3 test_vector(double3 D) {
+// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
+uint3 test_vector(double3 D) {
uint3 A, B;
asuint(D, A, B);
return A + B;
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 3a5d6349f85622..34ac423e278bba 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -93,6 +93,6 @@ def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>
def int_dx_splitdouble : DefaultAttrsIntrinsic<
[llvm_anyint_ty, LLVMMatchType<0>],
- [llvm_double_ty],
+ [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
[IntrNoMem, IntrWillReturn]>;
}
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 340914ff2cf422..344f7bb517c2bc 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -493,6 +493,12 @@ class OpLowerer {
Value *Arg0 = CI->getArgOperand(0);
+ if (Arg0->getType()->isVectorTy()) {
+ return make_error<StringError>(
+ "splitdouble doesn't support lowering vector types.",
+ inconvertibleErrorCode());
+ }
+
Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
std::array<Value *, 1> Args{Arg0};
diff --git a/llvm/test/CodeGen/DirectX/splitdouble.ll b/llvm/test/CodeGen/DirectX/splitdouble.ll
index c62b7dd2371ba2..6da3b5797b4cba 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble.ll
@@ -1,50 +1,17 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
-; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
-; Make sure DXILOpLowering is correctly generating the dxil op code call, with and without scalarizer.
+; Make sure DXILOpLowering is correctly generating the dxil op, with and without scalarizer.
-; CHECK-LABEL: define noundef float @test_scalar_double_split
-define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
+; CHECK-LABEL: define noundef i32 @test_scalar_double_split
+define noundef i32 @test_scalar_double_split(double noundef %D) local_unnamed_addr {
entry:
; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
- %0 = extractvalue { i32, i32 } %hlsl.asuint, 0
- %1 = extractvalue { i32, i32 } %hlsl.asuint, 1
+ %hlsl.splitdouble = call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
+ %0 = extractvalue { i32, i32 } %hlsl.splitdouble, 0
+ %1 = extractvalue { i32, i32 } %hlsl.splitdouble, 1
%add = add i32 %0, %1
- %conv = uitofp i32 %add to float
- ret float %conv
-}
-
-declare <2 x i32> @llvm.dx.splitdouble.v2i32(double) #1
-
-
-; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
-define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
-entry:
- %0 = extractelement <3 x double> %D, i64 0
- ; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
- ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- ; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
- %hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
- %1 = extractvalue { i32, i32 } %hlsl.asuint, 0
- %2 = extractvalue { i32, i32 } %hlsl.asuint, 1
- %3 = insertelement <3 x i32> poison, i32 %1, i64 0
- %4 = insertelement <3 x i32> poison, i32 %2, i64 0
- %5 = extractelement <3 x double> %D, i64 1
- %hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
- %6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
- %7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
- %8 = insertelement <3 x i32> %3, i32 %6, i64 1
- %9 = insertelement <3 x i32> %4, i32 %7, i64 1
- %10 = extractelement <3 x double> %D, i64 2
- %hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
- %11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
- %12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
- %13 = insertelement <3 x i32> %8, i32 %11, i64 2
- %14 = insertelement <3 x i32> %9, i32 %12, i64 2
- %add = add <3 x i32> %13, %14
- %conv = uitofp <3 x i32> %add to <3 x float>
- ret <3 x float> %conv
+ ret i32 %add
}
diff --git a/llvm/test/CodeGen/DirectX/splitdouble_error.ll b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
new file mode 100644
index 00000000000000..acfd52b24c9cc3
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
@@ -0,0 +1,16 @@
+; RUN: not opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+
+; DXIL operation splitdouble doesn't support vector types.
+; CHECK: in function test_vector_double_split
+; CHECK-SAME: splitdouble doesn't support lowering vector types.
+
+define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
+entry:
+ %hlsl.splitdouble = tail call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %D)
+ %0 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 0
+ %1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 1
+ %add = add <3 x i32> %0, %1
+ ret <3 x i32> %add
+}
+
+declare { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double>)
>From 91cff1b31b5ce932d9365afc9364a9ed0289cb1c Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Wed, 2 Oct 2024 22:07:42 +0000
Subject: [PATCH 13/16] remove spirv lowering
---
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b586ad1f4d50ce..d9377fe4b91a1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2569,8 +2569,6 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
}
case Intrinsic::spv_wave_readlane:
return selectWaveReadLaneAt(ResVReg, ResType, I);
- case Intrinsic::spv_splitdouble:
- return selectSplitdouble(ResVReg, ResType, I);
case Intrinsic::spv_step:
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
case Intrinsic::spv_radians:
>From dbf19f68952677e8665f7c0996c459f5f64a5a76 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 3 Oct 2024 01:24:08 +0000
Subject: [PATCH 14/16] adding spirv codegen
---
clang/lib/CodeGen/CGBuiltin.cpp | 22 ++++++++++++++++++++++
1 file changed, 22 insertions(+)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a996cdb8f15815..2165bda5588ede 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18994,7 +18994,29 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
EmitWritebacks(*this, Args);
return s;
+ }
+
+
+ if(!Op0->getType()->isVectorTy()){
+ FixedVectorType *destTy = FixedVectorType::get(Int32Ty, 2);
+ Value *bitcast = Builder.CreateBitCast(Op0, destTy);
+
+ Value *arg0 = Builder.CreateExtractElement(bitcast, 0.0);
+ Value *arg1 = Builder.CreateExtractElement(bitcast, 1.0);
+
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+
+ EmitWritebacks(*this, Args);
+ return s;
+ }
+
+ auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+ for(int idx = 0 ; idx < Op0VecTy -> getNumElements(); idx += 2){
+
}
+
}
}
return nullptr;
>From 0e3e8873048bf7c356f3fb89cf22a91f0730a0d8 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Fri, 4 Oct 2024 01:52:16 +0000
Subject: [PATCH 15/16] adding spirv codegen
---
clang/lib/CodeGen/CGBuiltin.cpp | 26 ++++++++++++++++++-
.../CodeGenHLSL/builtins/splitdouble.hlsl | 11 ++++++++
.../SPIRV/hlsl-intrinsics/splitdouble.ll | 16 ++++++++++++
.../hlsl-intrinsics/splitdouble_vector.ll | 14 ++++++++++
4 files changed, 66 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 2165bda5588ede..09c58d04a17145 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,8 +34,10 @@
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
@@ -19013,10 +19015,32 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
- for(int idx = 0 ; idx < Op0VecTy -> getNumElements(); idx += 2){
+ int numElements = Op0VecTy -> getNumElements() * 2;
+
+ FixedVectorType *destTy = FixedVectorType::get(Int32Ty, numElements);
+ Value *bitcast = Builder.CreateBitCast(Op0, destTy);
+
+ SmallVector<int> lowbitsIndex;
+ SmallVector<int> highbitsIndex;
+
+ for(int idx = 0; idx < numElements; idx += 2){
+ lowbitsIndex.push_back(idx);
+ }
+
+ for(int idx = 1; idx < numElements; idx += 2){
+ highbitsIndex.push_back(idx);
}
+ Value *arg0 = Builder.CreateShuffleVector(bitcast, lowbitsIndex);
+ Value *arg1 = Builder.CreateShuffleVector(bitcast, highbitsIndex);
+
+ Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+
+ EmitWritebacks(*this, Args);
+ return s;
+
}
}
return nullptr;
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index 4f3a2330af924e..e1f42824cfe5ed 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,4 +1,5 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPIRV
@@ -6,6 +7,11 @@
// CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[CAST:%.*]] = bitcast double [[VALD]] to <2 x i32>
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
uint test_scalar(double D) {
uint A, B;
asuint(D, A, B);
@@ -17,6 +23,11 @@ uint test_scalar(double D) {
// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[CAST:%.*]] = bitcast <3 x double> [[VALD]] to <6 x i32>
+// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
uint3 test_vector(double3 D) {
uint3 A, B;
asuint(D, A, B);
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
new file mode 100644
index 00000000000000..c057042c0d142e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
@@ -0,0 +1,16 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure lowering is correctly generating spirv code.
+
+define spir_func noundef i32 @test_scalar(double noundef %D) local_unnamed_addr {
+entry:
+ ; CHECK: %[[#]] = OpBitcast %[[#]] %[[#]]
+ %0 = bitcast double %D to <2 x i32>
+ ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 0
+ %1 = extractelement <2 x i32> %0, i64 0
+ ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 1
+ %2 = extractelement <2 x i32> %0, i64 1
+ %add = add i32 %1, %2
+ ret i32 %add
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
new file mode 100644
index 00000000000000..58bd5a046ff3d1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
@@ -0,0 +1,14 @@
+; RUN: opt -S -scalarizer -mtriple=spirv-vulkan-library %s 2>&1 | llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -o - | FileCheck %s
+
+; SPIRV lowering for splitdouble should rely on the scalarizer.
+
+define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
+entry:
+ ; CHECK-COUNT-3: %[[#]] = OpBitcast %[[#]] %[[#]]
+ ; CHECK-COUNT-3: %[[#]] = OpCompositeExtract %[[#]] %[[#]] [[0-2]]
+ %0 = bitcast <3 x double> %D to <6 x i32>
+ %1 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
+ %2 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
+ %add = add <3 x i32> %1, %2
+ ret <3 x i32> %add
+}
>From 119f454f46180abe8581aa9044dafe1e658f1c89 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 7 Oct 2024 22:15:53 +0000
Subject: [PATCH 16/16] addressing PR comments
---
clang/include/clang/Basic/Builtins.td | 14 +-
clang/lib/CodeGen/CGBuiltin.cpp | 247 ++++++++++++------
clang/lib/CodeGen/CGCall.cpp | 14 +-
clang/lib/CodeGen/CGExpr.cpp | 7 +-
clang/lib/CodeGen/CodeGenFunction.h | 9 +-
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 11 +-
clang/lib/Sema/SemaHLSL.cpp | 24 +-
.../CodeGenHLSL/builtins/splitdouble.hlsl | 83 +++++-
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/lib/Target/DirectX/DXIL.td | 4 +-
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 10 +-
llvm/lib/Target/DirectX/DXILOpBuilder.h | 2 +-
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 73 +++---
.../DirectX/DirectXTargetTransformInfo.cpp | 77 +++---
.../test/CodeGen/DirectX/splitdouble_error.ll | 7 +-
.../SPIRV/hlsl-intrinsics/splitdouble.ll | 40 ++-
.../hlsl-intrinsics/splitdouble_vector.ll | 14 -
17 files changed, 392 insertions(+), 246 deletions(-)
delete mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b35a24eda6da44..9bd67e0cefebc3 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4871,20 +4871,8 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}
-def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
- let Spellings = ["__builtin_hlsl_elementwise_radians"];
- let Attributes = [NoThrow, Const];
- let Prototype = "void(...)";
-}
-
-def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
- let Spellings = ["__builtin_hlsl_elementwise_radians"];
- let Attributes = [NoThrow, Const];
- let Prototype = "void(...)";
-}
-
def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
- let Spellings = ["__builtin_hlsl_splitdouble"];
+ let Spellings = ["__builtin_hlsl_elementwise_splitdouble"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 09c58d04a17145..d57a7f1d259d99 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -17,6 +17,7 @@
#include "CGObjCRuntime.h"
#include "CGOpenCLRuntime.h"
#include "CGRecordLayout.h"
+#include "CGValue.h"
#include "CodeGenFunction.h"
#include "CodeGenModule.h"
#include "ConstantEmitter.h"
@@ -25,8 +26,10 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
+#include "clang/AST/Expr.h"
#include "clang/AST/OSLog.h"
#include "clang/AST/OperationKinds.h"
+#include "clang/AST/Type.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Basic/TargetOptions.h"
@@ -34,14 +37,11 @@
#include "clang/Frontend/FrontendDiagnostic.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
@@ -70,6 +70,7 @@
#include "llvm/TargetParser/X86TargetParser.h"
#include <optional>
#include <sstream>
+#include <utility>
using namespace clang;
using namespace CodeGen;
@@ -98,6 +99,163 @@ static void initializeAlloca(CodeGenFunction &CGF, AllocaInst *AI, Value *Size,
I->addAnnotationMetadata("auto-init");
}
+static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
+ Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
+ const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+ const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+ CallArgList Args;
+ LValue Op1TmpLValue =
+ CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+ LValue Op2TmpLValue =
+ CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+ if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
+ Args.reverseWritebacks();
+
+ auto EmitVectorCode =
+ [](Value *Op, CGBuilderTy *Builder,
+ FixedVectorType *DestTy) -> std::pair<Value *, Value *> {
+ Value *bitcast = Builder->CreateBitCast(Op, DestTy);
+
+ SmallVector<int> LowbitsIndex;
+ SmallVector<int> HighbitsIndex;
+
+ for (unsigned int Idx = 0; Idx < DestTy->getNumElements(); Idx += 2) {
+ LowbitsIndex.push_back(Idx);
+ HighbitsIndex.push_back(Idx + 1);
+ }
+
+ Value *Arg0 = Builder->CreateShuffleVector(bitcast, LowbitsIndex);
+ Value *Arg1 = Builder->CreateShuffleVector(bitcast, HighbitsIndex);
+
+ return std::make_pair(Arg0, Arg1);
+ };
+
+ Value *LastInst = nullptr;
+
+ if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+
+ llvm::Type *RetElementTy = CGF->Int32Ty;
+ if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
+ RetElementTy = llvm::VectorType::get(
+ CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
+ auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
+
+ CallInst *CI = CGF->Builder.CreateIntrinsic(
+ RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
+
+ Value *Arg0 = CGF->Builder.CreateExtractValue(CI, 0);
+ Value *Arg1 = CGF->Builder.CreateExtractValue(CI, 1);
+
+ CGF->Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+ LastInst = CGF->Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+
+ } else {
+
+ assert(!CGF->CGM.getTarget().getTriple().isDXIL() &&
+ "For non-DXIL targets we generate the instructions");
+
+ if (!Op0->getType()->isVectorTy()) {
+ FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+ Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+
+ Value *Arg0 = CGF->Builder.CreateExtractElement(Bitcast, 0.0);
+ Value *Arg1 = CGF->Builder.CreateExtractElement(Bitcast, 1.0);
+
+ CGF->Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+ LastInst = CGF->Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+ } else {
+
+ const auto *TargTy = E->getArg(0)->getType()->getAs<clang::VectorType>();
+
+ int NumElements = TargTy->getNumElements();
+
+ FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 4);
+ if (NumElements == 1) {
+ FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+ Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
+
+ Value *Arg0 = CGF->Builder.CreateExtractElement(Bitcast, 0.0);
+ Value *Arg1 = CGF->Builder.CreateExtractElement(Bitcast, 1.0);
+
+ CGF->Builder.CreateStore(Arg0, Op1TmpLValue.getAddress());
+ LastInst = CGF->Builder.CreateStore(Arg1, Op2TmpLValue.getAddress());
+ } else if (NumElements == 2) {
+ auto [LowBits, HighBits] = EmitVectorCode(Op0, &CGF->Builder, DestTy);
+
+ CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
+ LastInst =
+ CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
+ } else {
+
+ SmallVector<std::pair<Value *, Value *>> EmitedValuePairs;
+
+ int isOdd = NumElements % 2;
+ int NumEvenElements = NumElements - isOdd;
+
+ for (int It = 0; It < NumEvenElements; It += 2) {
+ // Due to existing restrictions in SPIR-V and splitdouble,
+ // all shufflevector operations should return vectors of
+ // the same size, up to 4. This introduces an edge case
+ // when we get odd-sized vectors, which require an
+ // additional dummy value that is masked out in a later
+ // stage of this code.
+ auto Shuff = CGF->Builder.CreateShuffleVector(Op0, {It, It + 1});
+ std::pair<Value *, Value *> ValuePair =
+ EmitVectorCode(Shuff, &CGF->Builder, DestTy);
+ EmitedValuePairs.push_back(ValuePair);
+ }
+
+ if (isOdd == 1) {
+ FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
+ auto Shuff = CGF->Builder.CreateShuffleVector(Op0, {NumEvenElements});
+ std::pair<Value *, Value *> ValuePair =
+ EmitVectorCode(Shuff, &CGF->Builder, DestTy);
+ EmitedValuePairs.push_back(ValuePair);
+ }
+
+ SmallVector<int> Index = {0, 1};
+
+ auto arg0 = EmitedValuePairs[0].first;
+ auto arg1 = EmitedValuePairs[0].second;
+
+ auto EvenSizedPairs = EmitedValuePairs.size() - isOdd;
+
+ for (int It = 1; It < EvenSizedPairs; It++) {
+ int CurIndexSize = Index.size();
+ Index.insert(Index.end(), {CurIndexSize, CurIndexSize + 1});
+ arg0 = CGF->Builder.CreateShuffleVector(
+ arg0, EmitedValuePairs[It].first, Index);
+ arg1 = CGF->Builder.CreateShuffleVector(
+ arg1, EmitedValuePairs[It].second, Index);
+ }
+
+ if (isOdd == 1) {
+ int CurIndexSize = Index.size();
+
+ auto extendedLowerBits = CGF->Builder.CreateShuffleVector(
+ EmitedValuePairs[EvenSizedPairs].first, {0, 0});
+
+ auto extendedHighBits = CGF->Builder.CreateShuffleVector(
+ EmitedValuePairs[EvenSizedPairs].second, {0, 0});
+ Index.insert(Index.end(), {CurIndexSize});
+
+ arg0 =
+ CGF->Builder.CreateShuffleVector(arg0, extendedLowerBits, Index);
+ arg1 =
+ CGF->Builder.CreateShuffleVector(arg1, extendedHighBits, Index);
+ }
+
+ CGF->Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+ LastInst = CGF->Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+ }
+ }
+ }
+ CGF->EmitWritebacks(*CGF, Args);
+ return LastInst;
+}
+
/// getBuiltinLibFunction - Given a builtin id for a function like
/// "__builtin_fabsf", return a Function* for "fabsf".
llvm::Constant *CodeGenModule::getBuiltinLibFunction(const FunctionDecl *FD,
@@ -18955,92 +19113,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
nullptr, "hlsl.radians");
}
- // This should only be called when targeting DXIL
- case Builtin::BI__builtin_hlsl_splitdouble: {
+ case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
"asuint operands types mismatch");
- Value *Op0 = EmitScalarExpr(E->getArg(0));
- const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
- const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
-
- CallArgList Args;
- auto [Op1BaseLValue, Op1TmpLValue] =
- EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
- auto [Op2BaseLValue, Op2TmpLValue] =
- EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
-
- if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) {
-
- llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
-
- if (Op0->getType()->isVectorTy()) {
- auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
-
- llvm::VectorType *i32VecTy = llvm::VectorType::get(
- Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
- retType = llvm::StructType::get(i32VecTy, i32VecTy);
- }
-
- CallInst *CI =
- Builder.CreateIntrinsic(retType, Intrinsic::dx_splitdouble, {Op0},
- nullptr, "hlsl.splitdouble");
-
- Value *arg0 = Builder.CreateExtractValue(CI, 0);
- Value *arg1 = Builder.CreateExtractValue(CI, 1);
-
- Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
- EmitWritebacks(*this, Args);
- return s;
- }
-
-
- if(!Op0->getType()->isVectorTy()){
- FixedVectorType *destTy = FixedVectorType::get(Int32Ty, 2);
- Value *bitcast = Builder.CreateBitCast(Op0, destTy);
-
- Value *arg0 = Builder.CreateExtractElement(bitcast, 0.0);
- Value *arg1 = Builder.CreateExtractElement(bitcast, 1.0);
-
- Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
- EmitWritebacks(*this, Args);
- return s;
- }
-
- auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
-
- int numElements = Op0VecTy -> getNumElements() * 2;
-
- FixedVectorType *destTy = FixedVectorType::get(Int32Ty, numElements);
-
- Value *bitcast = Builder.CreateBitCast(Op0, destTy);
-
- SmallVector<int> lowbitsIndex;
- SmallVector<int> highbitsIndex;
-
- for(int idx = 0; idx < numElements; idx += 2){
- lowbitsIndex.push_back(idx);
- }
-
- for(int idx = 1; idx < numElements; idx += 2){
- highbitsIndex.push_back(idx);
- }
-
- Value *arg0 = Builder.CreateShuffleVector(bitcast, lowbitsIndex);
- Value *arg1 = Builder.CreateShuffleVector(bitcast, highbitsIndex);
-
- Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
- auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
-
- EmitWritebacks(*this, Args);
- return s;
-
+ return handleHlslSplitdouble(E, this);
}
}
return nullptr;
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 096bbafa4cc694..afa249b42bc630 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -40,6 +40,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/Path.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>
using namespace clang;
@@ -4197,7 +4198,7 @@ static void emitWriteback(CodeGenFunction &CGF,
// Release the old value.
CGF.EmitARCRelease(oldValue, srcLV.isARCPreciseLifetime());
- // Otherwise, we can just do a normal lvalue store.
+ // Otherwise, we can just do a normal lvalue store.
} else {
CGF.EmitStoreThroughLValue(RValue::get(value), srcLV);
}
@@ -4207,12 +4208,6 @@ static void emitWriteback(CodeGenFunction &CGF,
CGF.EmitBlock(contBB);
}
-static void emitWritebacks(CodeGenFunction &CGF,
- const CallArgList &args) {
- for (const auto &I : args.writebacks())
- emitWriteback(CGF, I);
-}
-
static void deactivateArgCleanupsBeforeCall(CodeGenFunction &CGF,
const CallArgList &CallArgs) {
ArrayRef<CallArgList::CallArgCleanup> Cleanups =
@@ -4683,7 +4678,8 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF,
const CallArgList &args) {
- emitWritebacks(CGF, args);
+ for (const auto &I : args.writebacks())
+ emitWriteback(CGF, I);
}
void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
@@ -5898,7 +5894,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
// Emit any call-associated writebacks immediately. Arguably this
// should happen after any return-value munging.
if (CallArgs.hasWritebacks())
- emitWritebacks(*this, CallArgs);
+ CodeGenFunction::EmitWritebacks(*this, CallArgs);
// The stack cleanup for inalloca arguments has to run out of the normal
// lexical order, so deactivate it and run it manually here.
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 8757b08079dab6..d2bfdf6437142e 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5478,9 +5478,8 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
return std::make_pair(BaseLV, TempLV);
}
-std::pair<LValue, LValue>
-CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
- QualType Ty) {
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+ CallArgList &Args, QualType Ty) {
auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
@@ -5495,7 +5494,7 @@ CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
LifetimeSize);
Args.add(RValue::get(TmpAddr, *this), Ty);
- return std::make_pair(BaseLV, TempLV);
+ return TempLV;
}
LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 2b48d55a98c9a0..6e2a4cca5e51f9 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4301,8 +4301,8 @@ class CodeGenFunction : public CodeGenTypeCache {
std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
QualType Ty);
- std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
- CallArgList &Args, QualType Ty);
+ LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+ QualType Ty);
Address EmitExtVectorElementLValue(LValue V);
@@ -5153,7 +5153,10 @@ class CodeGenFunction : public CodeGenTypeCache {
unsigned ParmNum);
/// EmitWriteback - Emit callbacks for function.
- void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args);
+ void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &Args);
+
+ void EmitWriteback(CodeGenFunction &CGF,
+ const CallArgList::Writeback &writeback);
/// EmitCallArg - Emit a single call argument.
void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index a669a673f65a69..8ade4b27f360fb 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -446,17 +446,16 @@ template <typename T> constexpr uint asuint(T F) {
/// \brief Split and interprets the lowbits and highbits of double D into uints.
/// \param D The input double.
/// \param lowbits The output lowbits of D.
-/// \param highbits The highbits lowbits D.
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+/// \param highbits The output highbits of D.
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double, out uint, out uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double2, out uint2, out uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double3, out uint3, out uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_splitdouble)
void asuint(double4, out uint4, out uint4);
-
//===----------------------------------------------------------------------===//
// atan builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4f7209ca7bd0a7..2a206b84b5bc73 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1698,7 +1698,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
return true;
}
-bool CheckArgTypeIsIncorrect(
+bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
@@ -1718,7 +1718,7 @@ bool CheckArgsTypesAreCorrect(
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
Expr *Arg = TheCall->getArg(i);
- if (CheckArgTypeIsIncorrect(S, Arg, ExpectedType, Check)) {
+ if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
@@ -2083,34 +2083,32 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
- case Builtin::BI__builtin_hlsl_splitdouble: {
+ case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
Expr *Op0 = TheCall->getArg(0);
- auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+ auto CheckIsDouble = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
};
- if (CheckArgTypeIsIncorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
- CheckIsNotDouble)) {
+ if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+ CheckIsDouble))
return true;
- }
Expr *Op1 = TheCall->getArg(1);
Expr *Op2 = TheCall->getArg(2);
- auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+ auto CheckIsUint = [](clang::QualType PassedType) -> bool {
return !PassedType->hasUnsignedIntegerRepresentation();
};
- if (CheckArgTypeIsIncorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
- CheckIsNotUint) ||
- CheckArgTypeIsIncorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
- CheckIsNotUint)) {
+ if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+ CheckIsUint) ||
+ CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+ CheckIsUint))
return true;
- }
break;
}
diff --git a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
index e1f42824cfe5ed..9568a45eeee93c 100644
--- a/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/splitdouble.hlsl
@@ -1,5 +1,5 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
-// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPIRV
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s --check-prefix=SPIRV
@@ -9,7 +9,8 @@
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
// SPIRV: define spir_func {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
// SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[CAST:%.*]] = bitcast double [[VALD]] to <2 x i32>
+// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[REG]] to <2 x i32>
// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
uint test_scalar(double D) {
@@ -18,18 +19,84 @@ uint test_scalar(double D) {
return A + B;
}
+// CHECK: define {{.*}} i32 {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[TRUNC:%.*]] = extractelement <1 x double> %D, i64 0
+// CHECK-NEXT: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[TRUNC]])
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} i32 {{.*}}test_double1{{.*}}(<1 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[REG:%.*]] = load <1 x double>, ptr [[VALD]].addr, align 8
+// SPIRV-NEXT: [[TRUNC:%.*]] = extractelement <1 x double> %1, i64 0
+// SPIRV-NEXT: [[CAST:%.*]] = bitcast double [[TRUNC]] to <2 x i32>
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 0
+// SPIRV-NEXT: extractelement <2 x i32> [[CAST]], i64 1
+uint test_double1(double1 D) {
+ uint A, B;
+ asuint(D, A, B);
+ return A + B;
+}
-// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// CHECK: define {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <2 x i32>, <2 x i32> } @llvm.dx.splitdouble.v2i32(<2 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <2 x i32>, <2 x i32> } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <2 x i32> {{.*}}test_vector2{{.*}}(<2 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[REG:%.*]] = load <2 x double>, ptr [[VALD]].addr, align 16
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[REG]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+uint2 test_vector2(double2 D) {
+ uint2 A, B;
+ asuint(D, A, B);
+ return A + B;
+}
+
+// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
-// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
+// SPIRV: define spir_func {{.*}} <3 x i32> {{.*}}test_vector3{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
// SPIRV-NOT: @llvm.dx.splitdouble.i32
-// SPIRV: [[CAST:%.*]] = bitcast <3 x double> [[VALD]] to <6 x i32>
-// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
-// SPIRV-NEXT: shufflevector <6 x i32> [[CAST]], <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
-uint3 test_vector(double3 D) {
+// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[VALRET1:%.*]] = shufflevector <3 x double> [[REG]], <3 x double> poison, <2 x i32> <i32 0, i32 1>
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[VALRET1]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+// SPIRV-NEXT: [[VALRET2:%.*]] = shufflevector <3 x double> [[REG]], <3 x double> poison, <1 x i32> <i32 2>
+// SPIRV-NEXT: [[CAST2:%.*]] = bitcast <1 x double> [[VALRET2]] to <2 x i32>
+// SPIRV-NEXT: [[SHUF3:%.*]] = shufflevector <2 x i32> [[CAST2]], <2 x i32> poison, <1 x i32> zeroinitializer
+// SPIRV-NEXT: [[SHUF4:%.*]] = shufflevector <2 x i32> [[CAST2]], <2 x i32> poison, <1 x i32> <i32 1>
+// SPIRV-NEXT: [[SHUF5:%.*]] = shufflevector <1 x i32> [[SHUF3]], <1 x i32> poison, <2 x i32> zeroinitializer
+// SPIRV-NEXT: [[SHUF6:%.*]] = shufflevector <1 x i32> [[SHUF4]], <1 x i32> poison, <2 x i32> zeroinitializer
+// SPIRV-NEXT: shufflevector <2 x i32> %4, <2 x i32> [[SHUF5]], <3 x i32> <i32 0, i32 1, i32 2>
+// SPIRV-NEXT: shufflevector <2 x i32> %5, <2 x i32> [[SHUF6]], <3 x i32> <i32 0, i32 1, i32 2>
+uint3 test_vector3(double3 D) {
uint3 A, B;
asuint(D, A, B);
return A + B;
}
+
+// CHECK: define {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
+// CHECK: [[VALRET:%.*]] = {{.*}} call { <4 x i32>, <4 x i32> } @llvm.dx.splitdouble.v4i32(<4 x double> [[VALD]])
+// CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 0
+// CHECK-NEXT: extractvalue { <4 x i32>, <4 x i32> } [[VALRET]], 1
+// SPIRV: define spir_func {{.*}} <4 x i32> {{.*}}test_vector4{{.*}}(<4 x double> {{.*}} [[VALD:%.*]])
+// SPIRV-NOT: @llvm.dx.splitdouble.i32
+// SPIRV: [[REG:%.*]] = load <4 x double>, ptr [[VALD]].addr, align 32
+// SPIRV-NEXT: [[VALRET1:%.*]] = shufflevector <4 x double> [[REG]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+// SPIRV-NEXT: [[CAST1:%.*]] = bitcast <2 x double> [[VALRET1]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF1:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF2:%.*]] = shufflevector <4 x i32> [[CAST1]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+// SPIRV-NEXT: [[VALRET2:%.*]] = shufflevector <4 x double> [[REG]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+// SPIRV-NEXT: [[CAST2:%.*]] = bitcast <2 x double> [[VALRET2]] to <4 x i32>
+// SPIRV-NEXT: [[SHUF3:%.*]] = shufflevector <4 x i32> [[CAST2]], <4 x i32> poison, <2 x i32> <i32 0, i32 2>
+// SPIRV-NEXT: [[SHUF4:%.*]] = shufflevector <4 x i32> [[CAST2]], <4 x i32> poison, <2 x i32> <i32 1, i32 3>
+// SPIRV-NEXT: shufflevector <2 x i32> [[SHUF1]], <2 x i32> [[SHUF3]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+// SPIRV-NEXT: shufflevector <2 x i32> [[SHUF2]], <2 x i32> [[SHUF4]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+uint4 test_vector4(double4 D) {
+ uint4 A, B;
+ asuint(D, A, B);
+ return A + B;
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 34ac423e278bba..91e5f9ab60bfb0 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -94,5 +94,5 @@ def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>
def int_dx_splitdouble : DefaultAttrsIntrinsic<
[llvm_anyint_ty, LLVMMatchType<0>],
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
- [IntrNoMem, IntrWillReturn]>;
+ [IntrNoMem]>;
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 338cc546348b8d..68ae5de06423c2 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,7 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
def HandleTy : DXILOpParamType;
def ResBindTy : DXILOpParamType;
def ResPropsTy : DXILOpParamType;
-def ResSplitDoubleTy : DXILOpParamType;
+def SplitDoubleTy : DXILOpParamType;
class DXILOpClass;
@@ -783,7 +783,7 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> {
def SplitDouble : DXILOp<102, splitDouble> {
let Doc = "Splits a double into 2 uints";
let arguments = [OverloadTy];
- let result = ResSplitDoubleTy;
+ let result = SplitDoubleTy;
let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 982d7849d9bb8b..5d5bb3eacace25 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,7 +229,7 @@ static StructType *getResPropsType(LLVMContext &Context) {
return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
}
-static StructType *getResSplitDoubleType(LLVMContext &Context) {
+static StructType *getSplitDoubleType(LLVMContext &Context) {
if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
return ST;
Type *Int32Ty = Type::getInt32Ty(Context);
@@ -273,8 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getResBindType(Ctx);
case OpParamType::ResPropsTy:
return getResPropsType(Ctx);
- case OpParamType::ResSplitDoubleTy:
- return getResSplitDoubleType(Ctx);
+ case OpParamType::SplitDoubleTy:
+ return getSplitDoubleType(Ctx);
}
llvm_unreachable("Invalid parameter kind");
return nullptr;
@@ -476,8 +476,8 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
return ::getResRetType(ElementTy);
}
-StructType *DXILOpBuilder::getResSplitDoubleType(LLVMContext &Context) {
- return ::getResSplitDoubleType(Context);
+StructType *DXILOpBuilder::getSplitDoubleType(LLVMContext &Context) {
+ return ::getSplitDoubleType(Context);
}
StructType *DXILOpBuilder::getHandleType() {
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h
index 8b1e87c283146c..df5a0240870f4a 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.h
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h
@@ -51,7 +51,7 @@ class DXILOpBuilder {
StructType *getResRetType(Type *ElementTy);
/// Get the `%dx.types.splitdouble` type.
- StructType *getResSplitDoubleType(LLVMContext &Context);
+ StructType *getSplitDoubleType(LLVMContext &Context);
/// Get the `%dx.types.Handle` type.
StructType *getHandleType();
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 344f7bb517c2bc..7eb9b1bb9660fe 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -23,9 +23,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
-#include "llvm/Object/Error.h"
#include "llvm/Pass.h"
-#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "dxil-op-lower"
@@ -131,6 +129,30 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool replaceFunctionWithNamedStructOp(
+ Function &F, dxil::OpCode DXILOp, Type *NewRetTy,
+ llvm::function_ref<Error(CallInst *CI, CallInst *Op)> ReplaceUses) {
+ bool IsVectorArgExpansion = isVectorArgExpansion(F);
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ SmallVector<Value *> Args;
+ OpBuilder.getIRB().SetInsertPoint(CI);
+ if (IsVectorArgExpansion) {
+ SmallVector<Value *> NewArgs = argVectorFlatten(CI, OpBuilder.getIRB());
+ Args.append(NewArgs.begin(), NewArgs.end());
+ } else
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), NewRetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+ if (Error E = ReplaceUses(CI, *OpCall))
+ return E;
+
+ return Error::success();
+ });
+ }
+
/// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
/// is intended to be removed by the end of lowering. This is used to allow
/// lowering of ops which need to change their return or argument types in a
@@ -267,20 +289,17 @@ class OpLowerer {
}
Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) {
- IRBuilder<> &IRB = OpBuilder.getIRB();
-
for (Use &U : make_early_inc_range(Intrin->uses())) {
if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
if (EVI->getNumIndices() != 1)
return createStringError(std::errc::invalid_argument,
"Splitdouble has only 2 elements");
-
- size_t IndexVal = EVI->getIndices()[0];
-
- auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal);
- EVI->replaceAllUsesWith(OpEVI);
- EVI->eraseFromParent();
+ EVI->setOperand(0, Op);
+ } else {
+ return make_error<StringError>(
+ "Splitdouble use is not ExtractValueInst",
+ inconvertibleErrorCode());
}
}
@@ -486,33 +505,6 @@ class OpLowerer {
});
}
- [[nodiscard]] bool lowerSplitDouble(Function &F) {
- IRBuilder<> &IRB = OpBuilder.getIRB();
- return replaceFunction(F, [&](CallInst *CI) -> Error {
- IRB.SetInsertPoint(CI);
-
- Value *Arg0 = CI->getArgOperand(0);
-
- if (Arg0->getType()->isVectorTy()) {
- return make_error<StringError>(
- "splitdouble doesn't support lowering vector types.",
- inconvertibleErrorCode());
- }
-
- Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
-
- std::array<Value *, 1> Args{Arg0};
- Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
- OpCode::SplitDouble, Args, CI->getName(), NewRetTy);
- if (Error E = OpCall.takeError())
- return E;
- if (Error E = replaceSplitDoubleCallUsages(CI, *OpCall))
- return E;
-
- return Error::success();
- });
- }
-
bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -542,7 +534,12 @@ class OpLowerer {
HasErrors |= lowerTypedBufferStore(F);
break;
case Intrinsic::dx_splitdouble:
- HasErrors |= lowerSplitDouble(F);
+ HasErrors |= replaceFunctionWithNamedStructOp(
+ F, OpCode::SplitDouble,
+ OpBuilder.getSplitDoubleType(M.getContext()),
+ [&](CallInst *CI, CallInst *Op) {
+ return replaceSplitDoubleCallUsages(CI, Op);
+ });
break;
}
Updated = true;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 8ea31401121bce..231afd8ae3eeaf 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -1,38 +1,39 @@
-//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
-//-*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-//===----------------------------------------------------------------------===//
-
-#include "DirectXTargetTransformInfo.h"
-#include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/IntrinsicsDirectX.h"
-
-using namespace llvm;
-
-bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
- unsigned ScalarOpdIdx) {
- switch (ID) {
- case Intrinsic::dx_wave_readlane:
- return ScalarOpdIdx == 1;
- default:
- return false;
- }
-}
-
-bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
- Intrinsic::ID ID) const {
- switch (ID) {
- case Intrinsic::dx_frac:
- case Intrinsic::dx_rsqrt:
- case Intrinsic::dx_wave_readlane:
- return true;
- default:
- return false;
- }
-}
+//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+//===----------------------------------------------------------------------===//
+
+#include "DirectXTargetTransformInfo.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+
+using namespace llvm;
+
+bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
+ unsigned ScalarOpdIdx) {
+ switch (ID) {
+ case Intrinsic::dx_wave_readlane:
+ return ScalarOpdIdx == 1;
+ default:
+ return false;
+ }
+}
+
+bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
+ Intrinsic::ID ID) const {
+ switch (ID) {
+ case Intrinsic::dx_frac:
+ case Intrinsic::dx_rsqrt:
+ case Intrinsic::dx_wave_readlane:
+ case Intrinsic::dx_splitdouble:
+ return true;
+ default:
+ return false;
+ }
+}
diff --git a/llvm/test/CodeGen/DirectX/splitdouble_error.ll b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
index acfd52b24c9cc3..d671660f6b2aa9 100644
--- a/llvm/test/CodeGen/DirectX/splitdouble_error.ll
+++ b/llvm/test/CodeGen/DirectX/splitdouble_error.ll
@@ -1,12 +1,11 @@
-; RUN: not opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
; DXIL operation splitdouble doesn't support vector types.
-; CHECK: in function test_vector_double_split
-; CHECK-SAME: splitdouble doesn't support lowering vector types.
+; XFAIL: *
define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
entry:
- %hlsl.splitdouble = tail call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %D)
+ %hlsl.splitdouble = call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %D)
%0 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 0
%1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 1
%add = add <3 x i32> %0, %1
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
index c057042c0d142e..41078fc7970b85 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble.ll
@@ -3,14 +3,48 @@
; Make sure lowering is correctly generating spirv code.
+; CHECK-DAG: %[[#double:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scalar_function:]] = OpTypeFunction %[[#int_32]] %[[#double]]
+; CHECK-DAG: %[[#vec_2_int_32:]] = OpTypeVector %[[#int_32]] 2
+; CHECK-DAG: %[[#vec_4_int_32:]] = OpTypeVector %[[#int_32]] 4
+; CHECK-DAG: %[[#vec_3_int_32:]] = OpTypeVector %[[#int_32]] 3
+; CHECK-DAG: %[[#vec_3_double:]] = OpTypeVector %[[#double]] 3
+; CHECK-DAG: %[[#vector_function:]] = OpTypeFunction %[[#vec_3_int_32]] %[[#vec_3_double]]
+; CHECK-DAG: %[[#vec_2_double:]] = OpTypeVector %[[#double]] 2
+
+
define spir_func noundef i32 @test_scalar(double noundef %D) local_unnamed_addr {
entry:
- ; CHECK: %[[#]] = OpBitcast %[[#]] %[[#]]
+ ; CHECK: %[[#]] = OpFunction %[[#int_32]] None %[[#scalar_function]]
+ ; CHECK: %[[#param:]] = OpFunctionParameter %[[#double]]
+ ; CHECK: %[[#bitcast:]] = OpBitcast %[[#vec_2_int_32]] %[[#param]]
%0 = bitcast double %D to <2 x i32>
- ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 0
+ ; CHECK: %[[#]] = OpCompositeExtract %[[#int_32]] %[[#bitcast]] 0
%1 = extractelement <2 x i32> %0, i64 0
- ; CHECK: %[[#]] = OpCompositeExtract %[[#]] %[[#]] 1
+ ; CHECK: %[[#]] = OpCompositeExtract %[[#int_32]] %[[#bitcast]] 1
%2 = extractelement <2 x i32> %0, i64 1
%add = add i32 %1, %2
ret i32 %add
}
+
+
+define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
+entry:
+ ; CHECK: %[[#]] = OpFunction %[[#vec_3_int_32]] None %[[#vector_function]]
+ ; CHECK: %[[#param:]] = OpFunctionParameter %[[#vec_3_double]]
+ ; CHECK: %[[#shuf1:]] = OpVectorShuffle %[[#vec_2_double]] %[[#param]] %[[#]] 0 1
+ %0 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 0, i32 1>
+ ; CHECK: %[[#shuf2:]] = OpVectorShuffle %[[#vec_2_double]] %[[#param]] %[[#]] 2 0
+ %1 = shufflevector <3 x double> %D, <3 x double> poison, <2 x i32> <i32 2, i32 0>
+ ; CHECK: %[[#cast1:]] = OpBitcast %[[#vec_4_int_32]] %[[#shuf1]]
+ %2 = bitcast <2 x double> %0 to <4 x i32>
+ ; CHECK: %[[#cast2:]] = OpBitcast %[[#vec_4_int_32]] %[[#shuf2]]
+ %3 = bitcast <2 x double> %1 to <4 x i32>
+ ; CHECK: %[[#]] = OpVectorShuffle %[[#vec_3_int_32]] %[[#cast1]] %[[#cast2]] 0 2 4
+ %4 = shufflevector <4 x i32> %2, <4 x i32> %3, <3 x i32> <i32 0, i32 2, i32 4>
+ ; CHECK: %[[#]] = OpVectorShuffle %[[#vec_3_int_32]] %[[#cast1]] %[[#cast2]] 1 3 5
+ %5 = shufflevector <4 x i32> %2, <4 x i32> %3, <3 x i32> <i32 1, i32 3, i32 5>
+ %add = add <3 x i32> %4, %5
+ ret <3 x i32> %add
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
deleted file mode 100644
index 58bd5a046ff3d1..00000000000000
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/splitdouble_vector.ll
+++ /dev/null
@@ -1,14 +0,0 @@
-; RUN: opt -S -scalarizer -mtriple=spirv-vulkan-library %s 2>&1 | llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown -o - | FileCheck %s
-
-; SPIRV lowering for splitdouble should relly on the scalarizer.
-
-define spir_func noundef <3 x i32> @test_vector(<3 x double> noundef %D) local_unnamed_addr {
-entry:
- ; CHECK-COUNT-3: %[[#]] = OpBitcast %[[#]] %[[#]]
- ; CHECK-COUNT-3: %[[#]] = OpCompositeExtract %[[#]] %[[#]] [[0-2]]
- %0 = bitcast <3 x double> %D to <6 x i32>
- %1 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 0, i32 2, i32 4>
- %2 = shufflevector <6 x i32> %0, <6 x i32> poison, <3 x i32> <i32 1, i32 3, i32 5>
- %add = add <3 x i32> %1, %2
- ret <3 x i32> %add
-}
More information about the cfe-commits
mailing list