[clang] [llvm] adding clang codegen (PR #109331)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 23 17:50:22 PDT 2024


https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/109331

>From 13a095ca2671cd69b120d9c394831b9ba8e20a50 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Thu, 19 Sep 2024 00:13:51 +0000
Subject: [PATCH 1/3] Codegen builtin

---
 clang/include/clang/Basic/Builtins.td         |  6 ++
 clang/lib/CodeGen/CGBuiltin.cpp               | 38 +++++++++++++
 clang/lib/CodeGen/CGCall.cpp                  |  5 ++
 clang/lib/CodeGen/CGExpr.cpp                  | 15 ++++-
 clang/lib/CodeGen/CodeGenFunction.h           | 10 +++-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 20 +++++++
 clang/lib/Sema/SemaHLSL.cpp                   | 56 ++++++++++++++++---
 .../builtins/asuint-splitdouble.hlsl          | 10 ++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  5 ++
 llvm/lib/Target/DirectX/DXIL.td               |  1 +
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 13 +++++
 11 files changed, 166 insertions(+), 13 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8c5d7ad763bf97..b38957f6e3f15d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4788,6 +4788,12 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2711f1ba70239..768cfb7f3b30f6 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18824,6 +18824,44 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         retType, CGM.getHLSLRuntime().getSignIntrinsic(),
         ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
   }
+  // This should only be called when targeting DXIL
+  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+
+    assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
+            E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
+            E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
+           "asuint operands types mismatch");
+
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
+    const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
+
+    CallArgList Args;
+    LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+    LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+
+    llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+    if (Op0->getType()->isVectorTy()) {
+      auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+      llvm::VectorType *i32VecTy = llvm::VectorType::get(
+          Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+
+      retType = llvm::StructType::get(i32VecTy, i32VecTy);
+    }
+
+    CallInst *CI =
+        Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
+                                {Op0}, nullptr, "hlsl.asuint");
+
+    Value *arg0 = Builder.CreateExtractValue(CI, 0);
+    Value *arg1 = Builder.CreateExtractValue(CI, 1);
+
+    Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+    auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+    EmitWritebacks(*this, Args);
+    return s;
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4ae981e4013e9c..096bbafa4cc694 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4681,6 +4681,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
   IsUsed = true;
 }
 
+void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF,
+                                     const CallArgList &args) {
+  emitWritebacks(CGF, args);
+}
+
 void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
                                   QualType type) {
   DisableDebugLocationUpdates Dis(*this, E);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 35b5daaf6d4b55..d53aecbf9f4741 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -19,6 +19,7 @@
 #include "CGObjCRuntime.h"
 #include "CGOpenMPRuntime.h"
 #include "CGRecordLayout.h"
+#include "CGValue.h"
 #include "CodeGenFunction.h"
 #include "CodeGenModule.h"
 #include "ConstantEmitter.h"
@@ -28,6 +29,7 @@
 #include "clang/AST/DeclObjC.h"
 #include "clang/AST/NSAPI.h"
 #include "clang/AST/StmtVisitor.h"
+#include "clang/AST/Type.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/CodeGenOptions.h"
 #include "clang/Basic/SourceManager.h"
@@ -5458,9 +5460,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) {
   return getOrCreateOpaqueLValueMapping(e);
 }
 
-void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
-                                         CallArgList &Args, QualType Ty) {
-
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
   // Emitting the casted temporary through an opaque value.
   LValue BaseLV = EmitLValue(E->getArgLValue());
   OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV);
@@ -5474,6 +5475,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
                                TempLV);
 
   OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV);
+  return std::make_pair(BaseLV, TempLV);
+}
+
+LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+                                           CallArgList &Args, QualType Ty) {
+
+  auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
 
   llvm::Value *Addr = TempLV.getAddress().getBasePointer();
   llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType());
@@ -5486,6 +5494,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
   Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
                     LifetimeSize);
   Args.add(RValue::get(TmpAddr, *this), Ty);
+  return TempLV;
 }
 
 LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 2df17e83bae2ee..7de6003b2c0d66 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitCastLValue(const CastExpr *E);
   LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E);
   LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e);
-  void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
-                          QualType Ty);
+
+  std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
+                                                  QualType Ty);
+  LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+                            QualType Ty);
 
   Address EmitExtVectorElementLValue(LValue V);
 
@@ -5145,6 +5148,9 @@ class CodeGenFunction : public CodeGenTypeCache {
                            SourceLocation ArgLoc, AbstractCallee AC,
                            unsigned ParmNum);
 
+  /// EmitWritebacks - Emit the writebacks recorded in \p args.
+  void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args);
+
   /// EmitCallArg - Emit a single call argument.
   void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6cd6a2caf19994..d48b60bab16d51 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -423,6 +423,26 @@ template <typename T> _HLSL_INLINE uint asuint(T F) {
   return __detail::bit_cast<uint, T>(F);
 }
 
+//===----------------------------------------------------------------------===//
+// asuint splitdouble builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn void asuint(double D, out uint lowbits, out uint highbits)
+/// \brief Reinterprets double D as two uints: its low and high 32 bits.
+/// \param D The input double.
+/// \param lowbits The output lowbits of D.
+/// \param highbits The highbits lowbits D.
+#if __is_target_arch(dxil)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double, out uint, out uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double2, out uint2, out uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double3, out uint3, out uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+void asuint(double4, out uint4, out uint4);
+#endif
+
 //===----------------------------------------------------------------------===//
 // atan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 03b7c2edb605fe..b3fffed7db3360 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1467,18 +1467,27 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
   return true;
 }
 
+bool CheckArgTypeIsCorrect(
+    Sema *S, Expr *Arg, QualType ExpectedType,
+    llvm::function_ref<bool(clang::QualType PassedType)> Check) {
+  QualType PassedType = Arg->getType();
+  if (Check(PassedType)) {
+    if (auto *VecTyA = PassedType->getAs<VectorType>())
+      ExpectedType = S->Context.getVectorType(
+          ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
+    S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+        << PassedType << ExpectedType << 1 << 0 << 0;
+    return true;
+  }
+  return false;
+}
+
 bool CheckArgsTypesAreCorrect(
     Sema *S, CallExpr *TheCall, QualType ExpectedType,
     llvm::function_ref<bool(clang::QualType PassedType)> Check) {
   for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
-    QualType PassedType = TheCall->getArg(i)->getType();
-    if (Check(PassedType)) {
-      if (auto *VecTyA = PassedType->getAs<VectorType>())
-        ExpectedType = S->Context.getVectorType(
-            ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
-      S->Diag(TheCall->getArg(0)->getBeginLoc(),
-              diag::err_typecheck_convert_incompatible)
-          << PassedType << ExpectedType << 1 << 0 << 0;
+    Expr *Arg = TheCall->getArg(i);
+    if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
       return true;
     }
   }
@@ -1762,6 +1771,37 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    // Expr *Op0 = TheCall->getArg(0);
+
+    // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+    //   return !PassedType->isDoubleType();
+    // };
+
+    // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+    //                           CheckIsNotDouble)) {
+    //   return true;
+    // }
+
+    // Expr *Op1 = TheCall->getArg(1);
+    // Expr *Op2 = TheCall->getArg(2);
+
+    // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+    //   return !PassedType->isUnsignedIntegerType();
+    // };
+
+    // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+    //                           CheckIsNotUint) ||
+    //     CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+    //                           CheckIsNotUint)) {
+    //   return true;
+    // }
+
+    break;
+  }
   case Builtin::BI__builtin_elementwise_acos:
   case Builtin::BI__builtin_elementwise_asin:
   case Builtin::BI__builtin_elementwise_atan:
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
new file mode 100644
index 00000000000000..e359354dc3a6df
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -0,0 +1,10 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s
+
+// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} 
+// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
+// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
+float fn(double D) {
+  uint A, B;
+  asuint(D, A, B);
+  return A + B;
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 3ce7b8b987ef86..d8092397881550 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -88,4 +88,9 @@ def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
+
+def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+    [llvm_anyint_ty, LLVMMatchType<0>], 
+    [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
+    [IntrNoMem, IntrWillReturn]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 9aa0af3e3a6b17..06c52da5fc07c8 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -778,6 +778,7 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
+//
 
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index dd73b895b14d37..2bc9ebc962e71a 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -12,6 +12,7 @@
 
 #include "DXILIntrinsicExpansion.h"
 #include "DirectX.h"
+#include "llvm-c/Core.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/DXILResource.h"
@@ -346,6 +347,15 @@ static Value *expandStepIntrinsic(CallInst *Orig) {
   return Builder.CreateSelect(Cond, Zero, One);
 }
 
+// static Value *expandSplitdoubleIntrinsic(CallInst *Orig) {
+//   Value *X = Orig->getOperand(0);
+//   Type *Ty = X->getType();
+//   IRBuilder<> Builder(Orig);
+
+//   Builder.CreateIntrinsic()
+
+// }
+
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
                                     Intrinsic::ID ClampIntrinsic) {
   if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -459,6 +469,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     break;
   case Intrinsic::dx_step:
     Result = expandStepIntrinsic(Orig);
+    break;
+    // case Intrinsic::dx_asuint_splitdouble:
+    //   Result = expandSplitdoubleIntrinsic(Orig);
   }
   if (Result) {
     Orig->replaceAllUsesWith(Result);

>From f3a9246f6cd447119cc825dd4539ce862d313ae2 Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Mon, 23 Sep 2024 21:19:12 +0000
Subject: [PATCH 2/3] adding vector case for splitdouble

---
 clang/lib/CodeGen/CGBuiltin.cpp               | 62 ++++++++++++++-----
 clang/lib/CodeGen/CGExpr.cpp                  |  8 ++-
 clang/lib/CodeGen/CodeGenFunction.h           |  4 +-
 .../builtins/asuint-splitdouble.hlsl          |  4 +-
 4 files changed, 57 insertions(+), 21 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 768cfb7f3b30f6..9b7cf1568b5754 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -34,12 +34,14 @@
 #include "clang/Frontend/FrontendDiagnostic.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
@@ -67,6 +69,7 @@
 #include "llvm/TargetParser/X86TargetParser.h"
 #include <optional>
 #include <sstream>
+#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
@@ -18836,29 +18839,62 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
     const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
 
+    auto emitSplitDouble =
+        [](CGBuilderTy *Builder, llvm::Value *arg,
+           llvm::Type *retType) -> std::pair<Value *, Value *> {
+      CallInst *CI = Builder->CreateIntrinsic(
+          retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
+          "hlsl.asuint");
+
+      Value *arg0 = Builder->CreateExtractValue(CI, 0);
+      Value *arg1 = Builder->CreateExtractValue(CI, 1);
+
+      return std::make_pair(arg0, arg1);
+    };
+
     CallArgList Args;
-    LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
-    LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
+    auto [Op1BaseLValue, Op1TmpLValue] =
+        EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
+    auto [Op2BaseLValue, Op2TmpLValue] =
+        EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
     llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
-    if (Op0->getType()->isVectorTy()) {
-      auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>();
 
-      llvm::VectorType *i32VecTy = llvm::VectorType::get(
-          Int32Ty, ElementCount::getFixed(XVecTy->getNumElements()));
+    if (!Op0->getType()->isVectorTy()) {
+      auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
+
+      Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
+      auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
 
-      retType = llvm::StructType::get(i32VecTy, i32VecTy);
+      EmitWritebacks(*this, Args);
+      return s;
     }
 
-    CallInst *CI =
-        Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble,
-                                {Op0}, nullptr, "hlsl.asuint");
+    auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
+
+    llvm::VectorType *i32VecTy = llvm::VectorType::get(
+        Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
 
-    Value *arg0 = Builder.CreateExtractValue(CI, 0);
-    Value *arg1 = Builder.CreateExtractValue(CI, 1);
+    std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
+
+    for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
+      Value *op = Builder.CreateExtractElement(Op0, idx);
+
+      auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType);
+
+      // arg0 feeds the first out argument (lowbits) and arg1 the second
+      // (highbits), so each must be inserted into its own result vector.
+      if (idx == 0) {
+        inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
+        inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
+      } else {
+        inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
+        inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
+      }
+    }
 
-    Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
-    auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
+    Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
+    auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
     EmitWritebacks(*this, Args);
     return s;
   }
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index d53aecbf9f4741..5d4c7193fa969b 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -54,6 +54,7 @@
 
 #include <optional>
 #include <string>
+#include <utility>
 
 using namespace clang;
 using namespace CodeGen;
@@ -5478,8 +5479,9 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) {
   return std::make_pair(BaseLV, TempLV);
 }
 
-LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
-                                           CallArgList &Args, QualType Ty) {
+std::pair<LValue, LValue>
+CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
+                                    QualType Ty) {
 
   auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty);
 
@@ -5494,7 +5496,7 @@ LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
   Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(),
                     LifetimeSize);
   Args.add(RValue::get(TmpAddr, *this), Ty);
-  return TempLV;
+  return std::make_pair(BaseLV, TempLV);
 }
 
 LValue
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 7de6003b2c0d66..7dbb80415d4ec0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4299,8 +4299,8 @@ class CodeGenFunction : public CodeGenTypeCache {
 
   std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E,
                                                   QualType Ty);
-  LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args,
-                            QualType Ty);
+  std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E,
+                                               CallArgList &Args, QualType Ty);
 
   Address EmitExtVectorElementLValue(LValue V);
 
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index e359354dc3a6df..4326612db96b0f 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -3,8 +3,8 @@
 // CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} 
 // CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
 // CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
-float fn(double D) {
-  uint A, B;
+float2 fn(double2 D) {
+  uint2 A, B;
   asuint(D, A, B);
   return A + B;
 }

>From bf404c41200b11fdc05bfa2a09989951cafd80bd Mon Sep 17 00:00:00 2001
From: Joao Saffran <jderezende at microsoft.com>
Date: Tue, 24 Sep 2024 00:50:10 +0000
Subject: [PATCH 3/3] adding lowering to dxil

---
 clang/include/clang/Basic/Builtins.td         |  4 +-
 clang/lib/CodeGen/CGBuiltin.cpp               |  6 +--
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  8 ++--
 clang/lib/Sema/SemaHLSL.cpp                   | 40 +++++++++----------
 .../builtins/asuint-splitdouble.hlsl          |  4 +-
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +-
 llvm/lib/Target/DirectX/DXIL.td               | 12 +++++-
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 25 +++++++-----
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     | 10 +++++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |  1 +
 10 files changed, 70 insertions(+), 42 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b38957f6e3f15d..4e0566615b5fef 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4788,8 +4788,8 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> {
-  let Spellings = ["__builtin_hlsl_asuint_splitdouble"];
+def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_splitdouble"];
   let Attributes = [NoThrow, Const];
   let Prototype = "void(...)";
 }
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 9b7cf1568b5754..9efe9257e55a1b 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18828,7 +18828,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
   }
   // This should only be called when targeting DXIL
-  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+  case Builtin::BI__builtin_hlsl_splitdouble: {
 
     assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
             E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
@@ -18843,7 +18843,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         [](CGBuilderTy *Builder, llvm::Value *arg,
            llvm::Type *retType) -> std::pair<Value *, Value *> {
       CallInst *CI = Builder->CreateIntrinsic(
-          retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr,
+          retType, llvm::Intrinsic::dx_splitdouble, {arg}, nullptr,
           "hlsl.asuint");
 
       Value *arg0 = Builder->CreateExtractValue(CI, 0);
@@ -18858,7 +18858,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
     auto [Op2BaseLValue, Op2TmpLValue] =
         EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
 
-    llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty);
+    llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
 
     if (!Op0->getType()->isVectorTy()) {
       auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index d48b60bab16d51..5816bb0f6967c5 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -433,13 +433,13 @@ template <typename T> _HLSL_INLINE uint asuint(T F) {
 /// \param lowbits The output lowbits of D.
-/// \param highbits The highbits lowbits D.
+/// \param highbits The output highbits of D.
 #if __is_target_arch(dxil)
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double, out uint, out uint);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double2, out uint2, out uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double3, out uint3, out uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble)
 void asuint(double4, out uint4, out uint4);
 #endif
 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b3fffed7db3360..48f3dabccc0ec2 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1771,34 +1771,42 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_asuint_splitdouble: {
+  case Builtin::BI__builtin_hlsl_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    // Expr *Op0 = TheCall->getArg(0);
+    Expr *Op0 = TheCall->getArg(0);
 
-    // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
-    //   return !PassedType->isDoubleType();
-    // };
+    // Require double or a vector of double: codegen forwards this operand
+    // unchanged to dx.splitdouble, which only accepts double elements.
+    auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool {
+      if (const auto *VecTy = PassedType->getAs<VectorType>())
+        PassedType = VecTy->getElementType();
+      return !PassedType->isDoubleType();
+    };
 
-    // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
-    //                           CheckIsNotDouble)) {
-    //   return true;
-    // }
+    if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy,
+                              CheckIsNotDouble)) {
+      return true;
+    }
 
-    // Expr *Op1 = TheCall->getArg(1);
-    // Expr *Op2 = TheCall->getArg(2);
+    Expr *Op1 = TheCall->getArg(1);
+    Expr *Op2 = TheCall->getArg(2);
 
-    // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
-    //   return !PassedType->isUnsignedIntegerType();
-    // };
+    // Each out argument receives one 32-bit half, so require exactly uint (or
+    // a vector of uint) rather than any unsigned type such as bool or uint64.
+    auto CheckIsNotUint = [](clang::QualType PassedType) -> bool {
+      if (const auto *VecTy = PassedType->getAs<VectorType>())
+        PassedType = VecTy->getElementType();
+      return !PassedType->isSpecificBuiltinType(BuiltinType::UInt);
+    };
 
-    // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
-    //                           CheckIsNotUint) ||
-    //     CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
-    //                           CheckIsNotUint)) {
-    //   return true;
-    // }
+    if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy,
+                              CheckIsNotUint) ||
+        CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy,
+                              CheckIsNotUint)) {
+      return true;
+    }
 
     break;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
index 4326612db96b0f..e359354dc3a6df 100644
--- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl
@@ -3,8 +3,8 @@
-// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} 
+// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]]){{.*}}
 // CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}}
-// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]])
+// CHECK: call { i32, i32 } @llvm.dx.splitdouble.{{.*}}(double [[VALD]])
-float2 fn(double2 D) {
-  uint2 A, B;
+float test_scalar(double D) {
+  uint A, B;
   asuint(D, A, B);
   return A + B;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index d8092397881550..04dd26ea54ca80 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -89,7 +89,7 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
 
-def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic<
+def int_dx_splitdouble : DefaultAttrsIntrinsic<
     [llvm_anyint_ty, LLVMMatchType<0>], 
     [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], 
     [IntrNoMem, IntrWillReturn]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 06c52da5fc07c8..004e65d9b896dd 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -47,6 +47,7 @@ def ResRetInt32Ty : DXILOpParamType;
 def HandleTy : DXILOpParamType;
 def ResBindTy : DXILOpParamType;
 def ResPropsTy : DXILOpParamType;
+def ResSplitDoubleTy : DXILOpParamType;
 
 class DXILOpClass;
 
@@ -778,7 +779,16 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
-//
+
+def SplitDouble :  DXILOp<102, splitDouble> {
+  let Doc = "Splits a double into 2 uints";
+  let LLVMIntrinsic = int_dx_splitdouble;
+  let arguments = [OverloadTy];
+  let result = ResSplitDoubleTy;
+  let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
 
 def AnnotateHandle : DXILOp<217, annotateHandle> {
   let Doc = "annotate handle with resource properties";
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 2bc9ebc962e71a..24363d3f00f041 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -17,6 +17,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/DXILResource.h"
 #include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
@@ -347,14 +348,19 @@ static Value *expandStepIntrinsic(CallInst *Orig) {
   return Builder.CreateSelect(Cond, Zero, One);
 }
 
-// static Value *expandSplitdoubleIntrinsic(CallInst *Orig) {
-//   Value *X = Orig->getOperand(0);
-//   Type *Ty = X->getType();
-//   IRBuilder<> Builder(Orig);
-
-//   Builder.CreateIntrinsic()
+static Value *expandSplitdoubleIntrinsic(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig);
+  Type *Int32Ty = Type::getInt32Ty(Orig->getContext());
+  Type *Ty =  StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
+  
+  auto *CI =
+      Builder.CreateIntrinsic(Ty, Intrinsic::dx_splitdouble, {X}, nullptr, "splitdouble");
+  CI->setTailCall(Orig->isTailCall());
+  CI->setAttributes(Orig->getAttributes());
+  return CI;
 
-// }
+}
 
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
                                     Intrinsic::ID ClampIntrinsic) {
@@ -470,8 +476,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_step:
     Result = expandStepIntrinsic(Orig);
     break;
-    // case Intrinsic::dx_asuint_splitdouble:
-    //   Result = expandSplitdoubleIntrinsic(Orig);
+  case Intrinsic::dx_splitdouble:
+    Result = expandSplitdoubleIntrinsic(Orig);
+    break;
   }
   if (Result) {
     Orig->replaceAllUsesWith(Result);
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 7719d6b1079110..8a159369ba65c3 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -229,6 +229,14 @@ static StructType *getResPropsType(LLVMContext &Context) {
   return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
 }
 
+static StructType *getResSplitDoubleType(LLVMContext &Context) {
+  if (auto *ST =
+          StructType::getTypeByName(Context, "dx.types.splitdouble"))
+    return ST;
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
+}
+
 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
                                     Type *OverloadTy) {
   switch (Kind) {
@@ -266,6 +274,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
     return getResBindType(Ctx);
   case OpParamType::ResPropsTy:
     return getResPropsType(Ctx);
+  case OpParamType::ResSplitDoubleTy:
+    return getResSplitDoubleType(Ctx);
   }
   llvm_unreachable("Invalid parameter kind");
   return nullptr;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 3ee3ee05563c24..96e958007b60f7 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -489,6 +489,7 @@ class OpLowerer {
       case Intrinsic::dx_typedBufferStore:
         HasErrors |= lowerTypedBufferStore(F);
         break;
+
       }
       Updated = true;
     }



More information about the llvm-commits mailing list