[clang] [HLSL] select scalar overloads for vector conditions (PR #129396)

Chris B via cfe-commits cfe-commits at lists.llvm.org
Sun Mar 9 10:21:02 PDT 2025


https://github.com/llvm-beanz updated https://github.com/llvm/llvm-project/pull/129396

>From 7620f9fac9932a13f1da0468b02c1aeceb212a0b Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Wed, 19 Feb 2025 17:18:20 -0600
Subject: [PATCH 1/5] [HLSL] select scalar overloads for vector conditions

This PR adds scalar/vector overloads for vector conditions to the
`select` builtin, and updates the sema checking and codegen to allow
scalars to extend to vectors.

Fixes #126570

clang-format
clang-format
'cbieneman/select' on '44f0fe9a2806'.
---
 .../clang/Basic/DiagnosticSemaKinds.td        |  3 +
 clang/lib/CodeGen/CGBuiltin.cpp               |  8 ++
 .../lib/Headers/hlsl/hlsl_alias_intrinsics.h  | 36 +++++++
 clang/lib/Headers/hlsl/hlsl_detail.h          |  5 +
 clang/lib/Sema/SemaHLSL.cpp                   | 56 ++++++-----
 clang/test/CodeGenHLSL/builtins/select.hlsl   | 29 ++++++
 .../test/SemaHLSL/BuiltIns/select-errors.hlsl | 98 +++++--------------
 7 files changed, 135 insertions(+), 100 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 21be7c358a61d..2514fb68bf5b0 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12707,6 +12707,9 @@ def err_hlsl_param_qualifier_mismatch :
 def err_hlsl_vector_compound_assignment_truncation : Error<
   "left hand operand of type %0 to compound assignment cannot be truncated "
   "when used with right hand operand of type %1">;
+def err_hlsl_builtin_scalar_vector_mismatch : Error<
+  "%select{all|second and third}0 arguments to %1 must be of scalar or "
+  "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
 
 def warn_hlsl_impcast_vector_truncation : Warning<
   "implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b86bb242755be..ba78de049ce96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19836,6 +19836,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         RValFalse.isScalar()
             ? RValFalse.getScalarVal()
             : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
+    if (auto *VTy = E->getType()->getAs<VectorType>()) {
+      if (!OpTrue->getType()->isVectorTy())
+        OpTrue =
+            Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
+      if (!OpFalse->getType()->isVectorTy())
+        OpFalse =
+            Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
+    }
 
     Value *SelectVal =
         Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 7573f6e024167..7a550a58e705c 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2123,6 +2123,42 @@ template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
 
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
+///                         vector<T,Sz> FalseVals)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVals The vector values are chosen from when conditions are
+/// false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVals The vector values are chosen from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
+///                         T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
+    vector<bool, Sz>, T, T);
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 39254a3cc3a0a..086c527614a43 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -97,6 +97,11 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+template<typename T>
+struct is_arithmetic {
+  static const bool Value = __is_arithmetic(T);
+};
+
 } // namespace __detail
 } // namespace hlsl
 #endif //_HLSL_HLSL_DETAILS_H_
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index aff349a932eec..2e6a333f3d768 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2225,40 +2225,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
 static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
+  QualType Arg1Ty = Arg1->getType();
   Expr *Arg2 = TheCall->getArg(2);
-  if (!Arg1->getType()->isVectorType()) {
-    S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
-        << "Second" << TheCall->getDirectCallee() << Arg1->getType()
+  QualType Arg2Ty = Arg2->getType();
+
+  QualType Arg1ScalarTy = Arg1Ty;
+  if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
+    Arg1ScalarTy = VTy->getElementType();
+
+  QualType Arg2ScalarTy = Arg2Ty;
+  if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
+    Arg2ScalarTy = VTy->getElementType();
+
+  if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
+    S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
+        << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
+
+  QualType Arg0Ty = TheCall->getArg(0)->getType();
+  unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
+  unsigned Arg1Length = Arg1Ty->isVectorType()
+                            ? Arg1Ty->getAs<VectorType>()->getNumElements()
+                            : 0;
+  unsigned Arg2Length = Arg2Ty->isVectorType()
+                            ? Arg2Ty->getAs<VectorType>()->getNumElements()
+                            : 0;
+  if (Arg1Length > 0 && Arg0Length != Arg1Length) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
         << Arg1->getSourceRange();
     return true;
   }
 
-  if (!Arg2->getType()->isVectorType()) {
-    S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
-        << "Third" << TheCall->getDirectCallee() << Arg2->getType()
-        << Arg2->getSourceRange();
-    return true;
-  }
-
-  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+  if (Arg2Length > 0 && Arg0Length != Arg2Length) {
     S->Diag(TheCall->getBeginLoc(),
-            diag::err_typecheck_call_different_arg_types)
-        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
         << Arg2->getSourceRange();
     return true;
   }
 
-  // caller has checked that Arg0 is a vector.
-  // check all three args have the same length.
-  if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
-      Arg1->getType()->getAs<VectorType>()->getNumElements()) {
-    S->Diag(TheCall->getBeginLoc(),
-            diag::err_typecheck_vector_lengths_not_equal)
-        << TheCall->getArg(0)->getType() << Arg1->getType()
-        << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
-    return true;
-  }
-  TheCall->setType(Arg1->getType());
+  TheCall->setType(
+      S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
   return false;
 }
 
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index cade938b71a2b..196b8a90cd877 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
 int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
   return select(cond0, tVals, fVals);
 }
+
+// CHECK-LABEL: test_select_vector_scalar_vector
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
+  return select(cond0, tVal, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_vector_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
+  return select(cond0, tVals, fVal);
+}
+
+// CHECK-LABEL: test_select_vector_scalar_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0
+// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
+  return select(cond0, tVal, fVal);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index 34b5fb6d54cd5..b445cedcba074 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -1,119 +1,65 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
 
-int test_no_arg() {
-  return select();
-  // expected-error at -1 {{no matching function for call to 'select'}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template
-  // not viable: requires 3 arguments, but 0 were provided}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 0 were provided}}
-}
-
-int test_too_few_args(bool p0) {
-  return select(p0);
-  // expected-error at -1 {{no matching function for call to 'select'}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 1 was provided}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 1 was provided}}
-}
-
-int test_too_many_args(bool p0, int t0, int f0, int g0) {
-  return select<int>(p0, t0, f0, g0);
-  // expected-error at -1 {{no matching function for call to 'select'}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 4 were provided}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: requires 3 arguments, but 4 were provided}}
-}
 
 int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
   return select(p0, t0, f0);
-  // expected-error at -1 {{no matching function for call to 'select'}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
-  // to 'bool' for 1st argument}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
-  // not match 'vector<T, Sz>' against 'int'}}
 }
 
 int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
   return select<int1>(p0, t0, f0);
-  // expected-warning at -1 {{implicit conversion truncates vector:
-  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
-  // (vector of 1 'int' value)}}
 }
 
 int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
                                                int f0) {
   return select(p0, t0, f0);
-  // expected-error at -1 {{no matching function for call to 'select'}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
-  // could not match 'vector<T, Sz>' against 'int'}}
-  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not
-  // viable: no known conversion from 'vector<bool, 2>'
-  // (vector of 2 'bool' values) to 'bool' for 1st argument}}
 }
 
 int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
-  return select<int,1>(p0, t0, f0); // produce warnings
-  // expected-warning at -1 {{implicit conversion truncates vector:
-  // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
-  // (vector of 1 'bool' value)}}
-  // expected-warning at -2 {{implicit conversion truncates vector:
-  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
-  // (vector of 1 'int' value)}}
+  return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}}
+}
+
+int test_select_no_args() {
+  return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}}
+}
+
+int test_select_builtin_wrong_arg_count(bool p0) {
+  return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}}
 }
 
 // __builtin_hlsl_select tests
-int test_select_builtin_wrong_arg_count(bool p0, int t0) {
-  return __builtin_hlsl_select(p0, t0);
-  // expected-error at -1 {{too few arguments to function call, expected 3,
-  // have 2}}
+int test_select_builtin_wrong_arg_count2(bool p0, int t0) {
+  return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+  return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}}
 }
 
 // not a bool or a vector of bool. should be 2 errors.
 int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error at -1 {{passing 'int' to parameter of incompatible type
-  // 'bool'}}
-  // expected-error at -2 {{First argument to __builtin_hlsl_select must be of
-  // vector type}}
-  }
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}}
+}
 
 int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error at -1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
-  // parameter of incompatible type 'bool'}}
-  // expected-error at -2 {{First argument to __builtin_hlsl_select must be of
-  // vector type}}
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}}
 }
 
 // if a bool last 2 args are of same type
 int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error at -1 {{arguments are of different types ('int' vs 'double')}}
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
 }
 
 // if a vector second arg isnt a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error at -1 {{Second argument to __builtin_hlsl_select must be of
-  // vector type}}
 }
 
 // if a vector third arg isn't a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error at -1 {{Third argument to __builtin_hlsl_select must be of
-  // vector type}}
 }
 
 // if vector last 2 aren't same type (so both are vectors but wrong type)
-int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error at -1 {{arguments are of different types ('vector<int, [...]>'
-  // vs 'vector<float, [...]>')}}
+int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
+  return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}}
 }

>From c93ecf3cc4deadbdfe735734675fdd689d0a1067 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Fri, 7 Mar 2025 10:47:48 -0600
Subject: [PATCH 2/5] Added comments to address @spall's feedback and
 clang-format

---
 clang/include/clang/Basic/DiagnosticSemaKinds.td |  7 ++++---
 clang/lib/Headers/hlsl/hlsl_detail.h             |  3 +--
 clang/test/SemaHLSL/BuiltIns/select-errors.hlsl  | 10 +++++-----
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 2514fb68bf5b0..14b0051709625 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12707,9 +12707,10 @@ def err_hlsl_param_qualifier_mismatch :
 def err_hlsl_vector_compound_assignment_truncation : Error<
   "left hand operand of type %0 to compound assignment cannot be truncated "
   "when used with right hand operand of type %1">;
-def err_hlsl_builtin_scalar_vector_mismatch : Error<
-  "%select{all|second and third}0 arguments to %1 must be of scalar or "
-  "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
+def err_hlsl_builtin_scalar_vector_mismatch
+    : Error<
+          "%select{all|second and third}0 arguments to %1 must be of scalar or "
+          "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
 
 def warn_hlsl_impcast_vector_truncation : Warning<
   "implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 086c527614a43..17a2ea9b54722 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -97,8 +97,7 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
-template<typename T>
-struct is_arithmetic {
+template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
 
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index b445cedcba074..7000affaef77f 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -2,16 +2,16 @@
 
 
 int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
-  return select(p0, t0, f0);
+  return select(p0, t0, f0); // No diagnostic expected.
 }
 
 int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
-  return select<int1>(p0, t0, f0);
+  return select<int1>(p0, t0, f0); // No diagnostic expected.
 }
 
 int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
                                                int f0) {
-  return select(p0, t0, f0);
+  return select(p0, t0, f0); // No diagnostic expected.
 }
 
 int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
@@ -51,12 +51,12 @@ int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
 
 // if a vector second arg isnt a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
+  return __builtin_hlsl_select(p0, t0, f0); // No diagnostic expected.
 }
 
 // if a vector third arg isn't a vector
 int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
-  return __builtin_hlsl_select(p0, t0, f0);
+  return __builtin_hlsl_select(p0, t0, f0); // No diagnostic expected.
 }
 
 // if vector last 2 aren't same type (so both are vectors but wrong type)

>From 135b1cdce0b7284a68e43873a10a16b389ee49e2 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Fri, 7 Mar 2025 16:31:57 -0600
Subject: [PATCH 3/5] Another slight shuffling of hlsl includes

This fixes issues that were causing the new select overloads to fail to
compile.
---
 clang/lib/Headers/CMakeLists.txt              |  1 +
 clang/lib/Headers/hlsl.h                      |  3 +
 clang/lib/Headers/hlsl/hlsl_detail.h          | 58 +--------------
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 71 +++++++++++++++++++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  2 +-
 5 files changed, 77 insertions(+), 58 deletions(-)
 create mode 100644 clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index e5bf8f35f7d52..d26de236998ca 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -87,6 +87,7 @@ set(hlsl_h
 set(hlsl_subdir_files
   hlsl/hlsl_basic_types.h
   hlsl/hlsl_alias_intrinsics.h
+  hlsl/hlsl_intrinsic_helpers.h
   hlsl/hlsl_intrinsics.h
   hlsl/hlsl_detail.h
   )
diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h
index 6edfd949f2b97..508e83e3a28e7 100644
--- a/clang/lib/Headers/hlsl.h
+++ b/clang/lib/Headers/hlsl.h
@@ -16,7 +16,10 @@
 #pragma clang diagnostic ignored "-Whlsl-dxc-compatability"
 #endif
 
+
+#include "hlsl/hlsl_detail.h"
 #include "hlsl/hlsl_basic_types.h"
+#include "hlsl/hlsl_alias_intrinsics.h"
 #include "hlsl/hlsl_intrinsics.h"
 
 #if defined(__clang__)
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 17a2ea9b54722..c691d85283de4 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -1,4 +1,4 @@
-//===----- detail.h - HLSL definitions for intrinsics ----------===//
+//===----- hlsl_detail.h - HLSL definitions for intrinsics ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,8 +9,6 @@
 #ifndef _HLSL_HLSL_DETAILS_H_
 #define _HLSL_HLSL_DETAILS_H_
 
-#include "hlsl_alias_intrinsics.h"
-
 namespace hlsl {
 
 namespace __detail {
@@ -43,60 +41,6 @@ constexpr enable_if_t<sizeof(U) == sizeof(T), U> bit_cast(T F) {
   return __builtin_bit_cast(U, F);
 }
 
-constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
-  // Use the same scaling factor used by FXC, and DXC for DXIL
-  // (i.e., 255.001953)
-  // https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
-  // The DXC implementation refers to a comment on the following stackoverflow
-  // discussion to justify the scaling factor: "Built-in rounding, necessary
-  // because of truncation. 0.001953 * 256 = 0.5"
-  // https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
-  return V.zyxw * 255.001953f;
-}
-
-template <typename T>
-constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
-length_impl(T X) {
-  return abs(X);
-}
-
-template <typename T, int N>
-constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
-length_vec_impl(vector<T, N> X) {
-#if (__has_builtin(__builtin_spirv_length))
-  return __builtin_spirv_length(X);
-#else
-  return sqrt(dot(X, X));
-#endif
-}
-
-template <typename T>
-constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
-distance_impl(T X, T Y) {
-  return length_impl(X - Y);
-}
-
-template <typename T, int N>
-constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
-distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
-  return length_vec_impl(X - Y);
-}
-
-template <typename T>
-constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
-reflect_impl(T I, T N) {
-  return I - 2 * N * I * N;
-}
-
-template <typename T, int L>
-constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
-#if (__has_builtin(__builtin_spirv_reflect))
-  return __builtin_spirv_reflect(I, N);
-#else
-  return I - 2 * N * dot(I, N);
-#endif
-}
-
 template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
new file mode 100644
index 0000000000000..6783e23f6346d
--- /dev/null
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -0,0 +1,71 @@
+//===----- hlsl_intrinsic_helpers.h - HLSL helpers intrinsics -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _HLSL_HLSL_INTRINSIC_HELPERS_H_
+#define _HLSL_HLSL_INTRINSIC_HELPERS_H_
+
+namespace hlsl {
+namespace __detail {
+
+constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
+  // Use the same scaling factor used by FXC, and DXC for DXIL
+  // (i.e., 255.001953)
+  // https://github.com/microsoft/DirectXShaderCompiler/blob/070d0d5a2beacef9eeb51037a9b04665716fd6f3/lib/HLSL/HLOperationLower.cpp#L666C1-L697C2
+  // The DXC implementation refers to a comment on the following stackoverflow
+  // discussion to justify the scaling factor: "Built-in rounding, necessary
+  // because of truncation. 0.001953 * 256 = 0.5"
+  // https://stackoverflow.com/questions/52103720/why-does-d3dcolortoubyte4-multiplies-components-by-255-001953f
+  return V.zyxw * 255.001953f;
+}
+
+template <typename T>
+constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
+length_impl(T X) {
+  return abs(X);
+}
+
+template <typename T, int N>
+constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
+length_vec_impl(vector<T, N> X) {
+#if (__has_builtin(__builtin_spirv_length))
+  return __builtin_spirv_length(X);
+#else
+  return sqrt(dot(X, X));
+#endif
+}
+
+template <typename T>
+constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
+distance_impl(T X, T Y) {
+  return length_impl(X - Y);
+}
+
+template <typename T, int N>
+constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
+distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
+  return length_vec_impl(X - Y);
+}
+
+template <typename T>
+constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
+reflect_impl(T I, T N) {
+  return I - 2 * N * I * N;
+}
+
+template <typename T, int L>
+constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
+#if (__has_builtin(__builtin_spirv_reflect))
+  return __builtin_spirv_reflect(I, N);
+#else
+  return I - 2 * N * dot(I, N);
+#endif
+}
+} // namespace __detail
+} // namespace hlsl
+
+#endif // _HLSL_HLSL_INTRINSIC_HELPERS_H_
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index fe9441080433d..cd6d836578787 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -9,7 +9,7 @@
 #ifndef _HLSL_HLSL_INTRINSICS_H_
 #define _HLSL_HLSL_INTRINSICS_H_
 
-#include "hlsl_detail.h"
+#include "hlsl/hlsl_intrinsic_helpers.h"
 
 namespace hlsl {
 

>From e69bf7e45c330e9de45ecc24f55470dcf34159c7 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Sun, 9 Mar 2025 11:57:28 -0500
Subject: [PATCH 4/5] NFC. Clang-format

---
 clang/lib/Headers/hlsl.h                       | 5 ++---
 clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h | 1 -
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h
index 508e83e3a28e7..59ee31c61b102 100644
--- a/clang/lib/Headers/hlsl.h
+++ b/clang/lib/Headers/hlsl.h
@@ -16,10 +16,9 @@
 #pragma clang diagnostic ignored "-Whlsl-dxc-compatability"
 #endif
 
-
-#include "hlsl/hlsl_detail.h"
-#include "hlsl/hlsl_basic_types.h"
 #include "hlsl/hlsl_alias_intrinsics.h"
+#include "hlsl/hlsl_basic_types.h"
+#include "hlsl/hlsl_detail.h"
 #include "hlsl/hlsl_intrinsics.h"
 
 #if defined(__clang__)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 7a550a58e705c..89dfeb475488e 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2123,7 +2123,6 @@ template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
 
-
 /// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
 ///                         vector<T,Sz> FalseVals)
 /// \brief ternary operator for vectors. All vectors must be the same size.

>From a33f73b5bce08a01c324c69b46a370fd8165d43c Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Sun, 9 Mar 2025 12:20:36 -0500
Subject: [PATCH 5/5] Make clang-format happy and don't break the include order

---
 clang/lib/Headers/hlsl.h | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h
index 59ee31c61b102..0d787a99d1019 100644
--- a/clang/lib/Headers/hlsl.h
+++ b/clang/lib/Headers/hlsl.h
@@ -16,9 +16,11 @@
 #pragma clang diagnostic ignored "-Whlsl-dxc-compatability"
 #endif
 
-#include "hlsl/hlsl_alias_intrinsics.h"
-#include "hlsl/hlsl_basic_types.h"
 #include "hlsl/hlsl_detail.h"
+
+#include "hlsl/hlsl_basic_types.h"
+
+#include "hlsl/hlsl_alias_intrinsics.h"
 #include "hlsl/hlsl_intrinsics.h"
 
 #if defined(__clang__)



More information about the cfe-commits mailing list