[clang] [HLSL] Implement support for HLSL intrinsic - select (PR #107129)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Fri Sep 6 18:45:58 PDT 2024


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/107129

>From 3e0cd3c450eb4aa28742c4879733987e9e2692e7 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 28 Aug 2024 01:44:35 +0000
Subject: [PATCH 1/9] implement select intrinsic

---
 clang/include/clang/Basic/Builtins.td    |  6 ++
 clang/lib/CodeGen/CGBuiltin.cpp          | 41 ++++++++++++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h | 24 +++++++
 clang/lib/Sema/SemaHLSL.cpp              | 84 ++++++++++++++++++++++++
 4 files changed, 155 insertions(+)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 8668b25661dec8..7e89f84319877b 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_select"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e4d169d2ad6030..7fe198c207ce22 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18695,6 +18695,47 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.saturate");
   }
+  case Builtin::BI__builtin_hlsl_select: {
+    Value *OpCond = EmitScalarExpr(E->getArg(0));
+    Value *OpTrue = EmitScalarExpr(E->getArg(1));
+    Value *OpFalse = EmitScalarExpr(E->getArg(2));
+    llvm::Type *TCond = OpCond->getType();
+
+    // if cond is a bool emit a select instruction
+    if (TCond->isIntegerTy(1))
+      return Builder.CreateSelect(OpCond, OpTrue, OpFalse);
+
+    // if cond is a vector of bools lower to a shufflevector
+    // todo check if that true and false are vectors
+    // todo check that the size of true and false and cond are the same
+    if (TCond->isVectorTy() &&
+	E->getArg(0)->getType()->getAs<VectorType>()->isBooleanType()) {
+      assert(OpTrue->getType()->isVectorTy() && OpFalse->getType()->isVectorTy() &&
+	     "Select's second and third operands must be vectors if first operand is a vector.");
+
+      auto *VecTyTrue = E->getArg(1)->getType()->getAs<VectorType>();
+      auto *VecTyFalse = E->getArg(2)->getType()->getAs<VectorType>();
+
+      assert(VecTyTrue->getElementType() == VecTyFalse->getElementType() &&
+	     "Select's second and third vectors need the same element types.");
+
+      const unsigned N = VecTyTrue->getNumElements();
+      assert(N == VecTyFalse->getNumElements() &&
+	     N == E->getArg(0)->getType()->getAs<VectorType>()->getNumElements() &&
+	     "Select requires vectors to be of the same size.");
+      
+      llvm::SmallVector<Value *> Mask;
+      for (unsigned I = 0; I < N; I++) {
+	Value *Index = ConstantInt::get(IntTy, I);
+	Value *IndexBool = Builder.CreateExtractElement(OpCond, Index);
+	Mask.push_back(Builder.CreateSelect(IndexBool, Index, ConstantInt::get(IntTy, I + N)));
+      }
+      
+      return Builder.CreateShuffleVector(OpTrue, OpFalse, BuildVector(Mask));
+    }
+    
+    llvm_unreachable("Select requires a bool or vector of bools as its first operand.");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6d38b668fe770e..4319d9775c4bae 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1603,6 +1603,30 @@ double3 saturate(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
 double4 saturate(double4);
 
+//===----------------------------------------------------------------------===//
+// select builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T select(bool Cond, T TrueVal, T FalseVal)
+/// \brief ternary operator.
+/// \param Cond The Condition input value.
+/// \param TrueVal The Value returned if Cond is true.
+/// \param FalseVal The Value returned if Cond is false.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+template<typename T>
+T select(bool, T, T);
+  
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, vector<T,Sz>, FalseVals)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVals The vector values are chosen from when conditions are true.
+/// \param FalseVals The vector values are chosen from when conditions are false.
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+template<typename T, int Sz>
+vector<T,Sz> select(vector<bool,Sz>, vector<T,Sz>, vector<T,Sz>);
+  
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index fabc6f32906b10..0da1d91ab7aa6b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1512,6 +1512,66 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
   TheCall->setType(ReturnType);
 }
 
+bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() == 3);
+  Expr *Arg1 = TheCall->getArg(1);
+  Expr *Arg2 = TheCall->getArg(2);
+  if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
+					Arg2->getType())) {
+    S->Diag(TheCall->getBeginLoc(),
+	    diag::err_typecheck_call_different_arg_types)
+      << Arg1->getType() << Arg2->getType()
+      << Arg1->getSourceRange() << Arg2->getSourceRange();
+    return true;
+  }
+
+  TheCall->setType(Arg1->getType());
+  return false;
+}
+
+bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() == 3);
+  Expr *Arg1 = TheCall->getArg(1);
+  Expr *Arg2 = TheCall->getArg(2);
+  if(!Arg1->getType()->isVectorType()) {
+    S->Diag(Arg1->getBeginLoc(),
+	    diag::err_builtin_non_vector_type)
+      << "Second" << "__builtin_hlsl_select" << Arg1->getType()
+      << Arg1->getSourceRange();
+    return true;
+  }
+  
+  if(!Arg2->getType()->isVectorType()) {
+    S->Diag(Arg2->getBeginLoc(),
+	    diag::err_builtin_non_vector_type)
+      << "Third" << "__builtin_hlsl_select" << Arg2->getType()
+      << Arg2->getSourceRange();
+    return true;
+  }
+
+  if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
+					Arg2->getType())) {
+    S->Diag(TheCall->getBeginLoc(),
+	    diag::err_typecheck_call_different_arg_types)
+      << Arg1->getType() << Arg2->getType()
+      << Arg1->getSourceRange() << Arg2->getSourceRange();
+    return true;
+  }
+
+  // caller has checked that Arg0 is a vector.
+  // check all three args have the same length.
+  if(TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
+     Arg1->getType()->getAs<VectorType>()->getNumElements()) {
+    S->Diag(TheCall->getBeginLoc(),
+	    diag::err_typecheck_vector_lengths_not_equal)
+      << TheCall->getArg(0)->getType() << Arg1->getType()
+      << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
+    return true;
+  }
+
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -1545,6 +1605,30 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
+  case Builtin::BI__builtin_hlsl_select: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+    QualType ArgTy = TheCall->getArg(0)->getType();
+    if (ArgTy->isBooleanType()) {
+      if (CheckBoolSelect(&SemaRef, TheCall))
+	return true;
+    } else if (ArgTy->isVectorType() &&
+	       ArgTy->getAs<VectorType>()->getElementType()->isBooleanType()) {
+      if (CheckVectorSelect(&SemaRef, TheCall))
+	return true;
+    } else { // first operand is not a bool or a vector of bools.
+      SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+		   diag::err_typecheck_convert_incompatible)
+	<< TheCall->getArg(0)->getType() << SemaRef.Context.getBOOLType()
+	<< 1 << 0 << 0;
+      SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+		   diag::err_builtin_non_vector_type)
+	<< "First" << "__builtin_hlsl_select" << TheCall->getArg(0)->getType()
+	<< TheCall->getArg(0)->getSourceRange();
+      return true;
+    }
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
       return true;

>From 52f4f39e1d558ef8ab37e925db2a9aad1de87ce0 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 30 Aug 2024 21:49:20 +0000
Subject: [PATCH 2/9] tests for select intrinsic

---
 clang/test/CodeGenHLSL/builtins/select.hlsl   | 76 ++++++++++++++++
 .../test/SemaHLSL/BuiltIns/select-errors.hlsl | 90 +++++++++++++++++++
 2 files changed, 166 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/select.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/select-errors.hlsl

diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
new file mode 100644
index 00000000000000..eeafd8ea852ea3
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK
+
+// CHECK: %hlsl.select = select i1
+// CHECK: ret i32 %hlsl.select
+int test_select_bool_int(bool cond0, int tVal, int fVal) { return select<int>(cond0, tVal, fVal); }
+
+// CHECK: %hlsl.select = select i1
+// CHECK: ret <2 x i32> %hlsl.select
+vector<int,2> test_select_bool_vector(bool cond0, vector<int, 2> tVal, vector<int, 2> fVal) { return select<vector<int,2> >(cond0, tVal, fVal); }
+
+// CHECK: %4 = extractelement <1 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <1 x i32> %2, i32 0
+// CHECK: %6 = extractelement <1 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <1 x i32> poison, i32 %7, i32 0
+// CHECK: ret <1 x i32> %8
+vector<int,1> test_select_vector_1(vector<bool,1> cond0, vector<int,1> tVals, vector<int,1> fVals) { return select<int,1>(cond0, tVals, fVals); }
+
+// CHECK: %4 = extractelement <2 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <2 x i32> %2, i32 0
+// CHECK: %6 = extractelement <2 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <2 x i32> poison, i32 %7, i32 0
+// CHECK: %9 = extractelement <2 x i1> %extractvec, i32 1
+// CHECK: %10 = extractelement <2 x i32> %2, i32 1
+// CHECK: %11 = extractelement <2 x i32> %3, i32 1
+// CHECK: %12 = select i1 %9, i32 %10, i32 %11
+// CHECK: %13 = insertelement <2 x i32> %8, i32 %12, i32 1
+// CHECK: ret <2 x i32> %13
+vector<int,2> test_select_vector_2(vector<bool, 2> cond0, vector<int, 2> tVals, vector<int, 2> fVals) { return select<int,2>(cond0, tVals, fVals); }
+
+// CHECK: %4 = extractelement <3 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <3 x i32> %2, i32 0
+// CHECK: %6 = extractelement <3 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <3 x i32> poison, i32 %7, i32 0
+// CHECK: %9 = extractelement <3 x i1> %extractvec, i32 1
+// CHECK: %10 = extractelement <3 x i32> %2, i32 1
+// CHECK: %11 = extractelement <3 x i32> %3, i32 1
+// CHECK: %12 = select i1 %9, i32 %10, i32 %11
+// CHECK: %13 = insertelement <3 x i32> %8, i32 %12, i32 1
+// CHECK: %14 = extractelement <3 x i1> %extractvec, i32 2
+// CHECK: %15 = extractelement <3 x i32> %2, i32 2
+// CHECK: %16 = extractelement <3 x i32> %3, i32 2
+// CHECK: %17 = select i1 %14, i32 %15, i32 %16
+// CHECK: %18 = insertelement <3 x i32> %13, i32 %17, i32 2
+// CHECK: ret <3 x i32> %18
+vector<int,3> test_select_vector_3(vector<bool, 3> cond0, vector<int, 3> tVals, vector<int, 3> fVals) { return select<int,3>(cond0, tVals, fVals); }
+
+// CHECK: %4 = extractelement <4 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <4 x i32> %2, i32 0
+// CHECK: %6 = extractelement <4 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <4 x i32> poison, i32 %7, i32 0
+// CHECK: %9 = extractelement <4 x i1> %extractvec, i32 1
+// CHECK: %10 = extractelement <4 x i32> %2, i32 1
+// CHECK: %11 = extractelement <4 x i32> %3, i32 1
+// CHECK: %12 = select i1 %9, i32 %10, i32 %11
+// CHECK: %13 = insertelement <4 x i32> %8, i32 %12, i32 1
+// CHECK: %14 = extractelement <4 x i1> %extractvec, i32 2
+// CHECK: %15 = extractelement <4 x i32> %2, i32 2
+// CHECK: %16 = extractelement <4 x i32> %3, i32 2
+// CHECK: %17 = select i1 %14, i32 %15, i32 %16
+// CHECK: %18 = insertelement <4 x i32> %13, i32 %17, i32 2
+// CHECK: %19 = extractelement <4 x i1> %extractvec, i32 3
+// CHECK: %20 = extractelement <4 x i32> %2, i32 3
+// CHECK: %21 = extractelement <4 x i32> %3, i32 3
+// CHECK: %22 = select i1 %19, i32 %20, i32 %21
+// CHECK: %23 = insertelement <4 x i32> %18, i32 %22, i32 3
+// CHECK: ret <4 x i32> %23
+vector<int,4> test_select_vector_4(vector<bool, 4> cond0, vector<int, 4> tVals, vector<int, 4> fVals) { return select<int,4>(cond0, tVals, fVals); }
+
+
+
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
new file mode 100644
index 00000000000000..5637c5f8176e90
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+int test_no_arg() {
+  return select();
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 0 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 0 were provided}}
+}
+
+int test_too_few_args(bool p0) {
+  return select(p0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+  return select<int>(p0, t0, f0, g0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+}
+
+int test_select_first_arg_wrong_type(vector<int,1> p0, int t0, int f0) {
+  return select(p0, t0, f0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value) to 'bool' for 1st argument}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could not match 'vector<T, Sz>' against 'int'}}
+}
+
+vector<int,1> test_select_bool_vals_diff_vecs(bool p0, vector<int,1> t0, vector<int,2> f0) {
+  return select<vector<int,1> >(p0, t0, f0);
+  // expected-warning@-1 {{implicit conversion truncates vector: 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' (vector of 1 'int' value)}}
+}
+
+vector<int,2> test_select_vector_vals_not_vecs(vector<bool,2> p0, int t0, int f0) {
+  return select(p0, t0, f0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could not match 'vector<T, Sz>' against 'int'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: no known conversion from 'vector<bool, 2>' (vector of 2 'bool' values) to 'bool' for 1st argument}}
+}
+
+vector<int,1> test_select_vector_vals_wrong_size(vector<bool,2> p0, vector<int,1> t0, vector<int,2> f0) {
+  return select<int,1>(p0, t0, f0); // produce warnings
+  // expected-warning@-1 {{implicit conversion truncates vector: 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>' (vector of 1 'bool' value)}}
+  // expected-warning@-2 {{implicit conversion truncates vector: 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' (vector of 1 'int' value)}}
+}
+
+// __builtin_hlsl_select tests
+int test_select_builtin_wrong_arg_count(bool p0, int t0) {
+  return __builtin_hlsl_select(p0, t0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+// not a bool or a vector of bool. should be 2 errors.
+int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of vector type}}
+  }
+
+int test_select_builtin_first_arg_wrong_type2(vector<int,1> p0, int t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to parameter of incompatible type 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of vector type}}
+}
+
+// if a bool last 2 args are of same type
+int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
+}
+
+// if a vector second arg isnt a vector
+vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0, int t0, vector<int,2> f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of vector type}}
+}
+
+// if a vector third arg isn't a vector
+vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0, vector<int,2> t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of vector type}}
+}
+
+// if vector last 2 aren't same type (so both are vectors but wrong type)
+vector<int,2> test_select_builtin_diff_types(vector<bool,1> p0, vector<int,1> t0, vector<float,1> f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>' vs 'vector<float, [...]>')}}
+}
\ No newline at end of file

>From 7ff44a894c59c34d0aacba57d5eb2d6f63e9cf6f Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 30 Aug 2024 21:54:49 +0000
Subject: [PATCH 3/9] fix bugs revealed during testing

---
 clang/lib/CodeGen/CGBuiltin.cpp          | 17 ++++++++++-------
 clang/lib/Headers/hlsl/hlsl_intrinsics.h |  4 ++--
 clang/lib/Sema/SemaHLSL.cpp              |  2 +-
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe198c207ce22..1aaa668607dab0 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18703,13 +18703,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
 
     // if cond is a bool emit a select instruction
     if (TCond->isIntegerTy(1))
-      return Builder.CreateSelect(OpCond, OpTrue, OpFalse);
+      return Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
 
     // if cond is a vector of bools lower to a shufflevector
     // todo check if that true and false are vectors
     // todo check that the size of true and false and cond are the same
     if (TCond->isVectorTy() &&
-	E->getArg(0)->getType()->getAs<VectorType>()->isBooleanType()) {
+	E->getArg(0)->getType()->getAs<VectorType>()->getElementType()->isBooleanType()) {
       assert(OpTrue->getType()->isVectorTy() && OpFalse->getType()->isVectorTy() &&
 	     "Select's second and third operands must be vectors if first operand is a vector.");
 
@@ -18723,15 +18723,18 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
       assert(N == VecTyFalse->getNumElements() &&
 	     N == E->getArg(0)->getType()->getAs<VectorType>()->getNumElements() &&
 	     "Select requires vectors to be of the same size.");
-      
-      llvm::SmallVector<Value *> Mask;
+
+      llvm::Value *Result = llvm::PoisonValue::get(llvm::FixedVectorType::get(IntTy, N));
       for (unsigned I = 0; I < N; I++) {
 	Value *Index = ConstantInt::get(IntTy, I);
 	Value *IndexBool = Builder.CreateExtractElement(OpCond, Index);
-	Mask.push_back(Builder.CreateSelect(IndexBool, Index, ConstantInt::get(IntTy, I + N)));
+	Value *TVal = Builder.CreateExtractElement(OpTrue, Index);
+	Value *FVal = Builder.CreateExtractElement(OpFalse, Index);
+	Value *IndexSelect = Builder.CreateSelect(IndexBool, TVal, FVal);
+	Result = Builder.CreateInsertElement(Result, IndexSelect, Index); 
       }
-      
-      return Builder.CreateShuffleVector(OpTrue, OpFalse, BuildVector(Mask));
+
+      return Result;
     }
     
     llvm_unreachable("Select requires a bool or vector of bools as its first operand.");
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 4319d9775c4bae..da14e837f7bd0a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1613,8 +1613,8 @@ double4 saturate(double4);
 /// \param TrueVal The Value returned if Cond is true.
 /// \param FalseVal The Value returned if Cond is false.
 
+template <typename T>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
-template<typename T>
 T select(bool, T, T);
   
 /// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, vector<T,Sz>, FalseVals)
@@ -1623,8 +1623,8 @@ T select(bool, T, T);
 /// \param TrueVals The vector values are chosen from when conditions are true.
 /// \param FalseVals The vector values are chosen from when conditions are false.
 
+template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
-template<typename T, int Sz>
 vector<T,Sz> select(vector<bool,Sz>, vector<T,Sz>, vector<T,Sz>);
   
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 0da1d91ab7aa6b..62c825c9c852a7 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1619,7 +1619,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     } else { // first operand is not a bool or a vector of bools.
       SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
 		   diag::err_typecheck_convert_incompatible)
-	<< TheCall->getArg(0)->getType() << SemaRef.Context.getBOOLType()
+	<< TheCall->getArg(0)->getType() << getASTContext().BoolTy
 	<< 1 << 0 << 0;
       SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
 		   diag::err_builtin_non_vector_type)

>From 735b0854659c734a9edc442ba6f20f11be2ee217 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Tue, 3 Sep 2024 21:53:43 +0000
Subject: [PATCH 4/9] move select case above saturate case so saturate falls
 through correctly.

---
 clang/lib/Sema/SemaHLSL.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 62c825c9c852a7..b3b05db3b757b5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1604,7 +1604,6 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_select: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
@@ -1629,6 +1628,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     }
     break;
   }
+  case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
       return true;

>From f8f4a4e31aa4c49fc091a5b916013e6d66256ace Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Tue, 3 Sep 2024 22:00:15 +0000
Subject: [PATCH 5/9] remove no longer needed todo comments.

---
 clang/lib/CodeGen/CGBuiltin.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 1aaa668607dab0..6709e15d383427 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18706,8 +18706,6 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
       return Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
 
     // if cond is a vector of bools lower to a shufflevector
-    // todo check if that true and false are vectors
-    // todo check that the size of true and false and cond are the same
     if (TCond->isVectorTy() &&
 	E->getArg(0)->getType()->getAs<VectorType>()->getElementType()->isBooleanType()) {
       assert(OpTrue->getType()->isVectorTy() && OpFalse->getType()->isVectorTy() &&

>From 4803c5b111e989b6e50c1150609a03e3ecab7148 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 5 Sep 2024 18:22:04 +0000
Subject: [PATCH 6/9] address comments on PR. Simplify codegen of hlsl_select
 because llvm select instr already does exactly what we want. Add handling for
 aggregate types.

---
 .../clang/Basic/DiagnosticSemaKinds.td        |  3 +
 clang/lib/CodeGen/CGBuiltin.cpp               | 70 ++++++-------
 clang/lib/CodeGen/CodeGenFunction.h           |  3 +-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  6 +-
 clang/lib/Sema/SemaHLSL.cpp                   | 59 ++++++-----
 clang/test/CodeGenHLSL/builtins/select.hlsl   | 97 +++++++------------
 .../test/SemaHLSL/BuiltIns/select-errors.hlsl | 97 +++++++++++++------
 7 files changed, 172 insertions(+), 163 deletions(-)

diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index dcb49d8a67604a..68c6993089dcb8 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -9206,6 +9206,9 @@ def err_typecheck_expect_scalar_operand : Error<
   "operand of type %0 where arithmetic or pointer type is required">;
 def err_typecheck_cond_incompatible_operands : Error<
   "incompatible operand types%diff{ ($ and $)|}0,1">;
+def err_typecheck_expect_scalar_or_vector : Error<
+  "invalid operand of type %0 where %1 or "
+  "a vector of such type is required">;
 def err_typecheck_expect_flt_or_vector : Error<
   "invalid operand of type %0 where floating, complex or "
   "a vector of such types is required">;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 6709e15d383427..97e56b411bff46 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6241,8 +6241,20 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
   }
 
   // EmitHLSLBuiltinExpr will check getLangOpts().HLSL
-  if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
-    return RValue::get(V);
+  if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E, ReturnValue)) {
+    switch (EvalKind) {
+    case TEK_Scalar:
+      if (V->getType()->isVoidTy())
+        return RValue::get(nullptr);
+      return RValue::get(V);
+    case TEK_Aggregate:
+      return RValue::getAggregate(ReturnValue.getAddress(),
+                                  ReturnValue.isVolatile());
+    case TEK_Complex:
+      llvm_unreachable("No current hlsl builtin returns complex");
+    }
+    llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
+  }
 
   if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
     return EmitHipStdParUnsupportedBuiltin(this, FD);
@@ -18508,7 +18520,8 @@ Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
 }
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
-                                            const CallExpr *E) {
+                                            const CallExpr *E,
+					    ReturnValueSlot ReturnValue) {
   if (!getLangOpts().HLSL)
     return nullptr;
 
@@ -18697,45 +18710,20 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
   }
   case Builtin::BI__builtin_hlsl_select: {
     Value *OpCond = EmitScalarExpr(E->getArg(0));
-    Value *OpTrue = EmitScalarExpr(E->getArg(1));
-    Value *OpFalse = EmitScalarExpr(E->getArg(2));
-    llvm::Type *TCond = OpCond->getType();
-
-    // if cond is a bool emit a select instruction
-    if (TCond->isIntegerTy(1))
-      return Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
-
-    // if cond is a vector of bools lower to a shufflevector
-    if (TCond->isVectorTy() &&
-	E->getArg(0)->getType()->getAs<VectorType>()->getElementType()->isBooleanType()) {
-      assert(OpTrue->getType()->isVectorTy() && OpFalse->getType()->isVectorTy() &&
-	     "Select's second and third operands must be vectors if first operand is a vector.");
-
-      auto *VecTyTrue = E->getArg(1)->getType()->getAs<VectorType>();
-      auto *VecTyFalse = E->getArg(2)->getType()->getAs<VectorType>();
-
-      assert(VecTyTrue->getElementType() == VecTyFalse->getElementType() &&
-	     "Select's second and third vectors need the same element types.");
-
-      const unsigned N = VecTyTrue->getNumElements();
-      assert(N == VecTyFalse->getNumElements() &&
-	     N == E->getArg(0)->getType()->getAs<VectorType>()->getNumElements() &&
-	     "Select requires vectors to be of the same size.");
-
-      llvm::Value *Result = llvm::PoisonValue::get(llvm::FixedVectorType::get(IntTy, N));
-      for (unsigned I = 0; I < N; I++) {
-	Value *Index = ConstantInt::get(IntTy, I);
-	Value *IndexBool = Builder.CreateExtractElement(OpCond, Index);
-	Value *TVal = Builder.CreateExtractElement(OpTrue, Index);
-	Value *FVal = Builder.CreateExtractElement(OpFalse, Index);
-	Value *IndexSelect = Builder.CreateSelect(IndexBool, TVal, FVal);
-	Result = Builder.CreateInsertElement(Result, IndexSelect, Index); 
-      }
-
-      return Result;
-    }
+    RValue RValTrue = EmitAnyExpr(E->getArg(1));
+    Value *OpTrue = RValTrue.isScalar() ? RValTrue.getScalarVal()
+      : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
+    RValue RValFalse = EmitAnyExpr(E->getArg(2));
+    Value *OpFalse = RValFalse.isScalar() ? RValFalse.getScalarVal()
+      : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
     
-    llvm_unreachable("Select requires a bool or vector of bools as its first operand.");
+    Value *SelectVal = Builder.CreateSelect(OpCond, OpTrue, OpFalse,
+					    "hlsl.select");
+    if (!RValTrue.isScalar())
+      Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
+			  ReturnValue.isVolatile());
+
+    return SelectVal;
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     return EmitRuntimeCall(CGM.CreateRuntimeFunction(
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 368fc112187ffc..81a0b666183cc7 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4700,7 +4700,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
-  llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+  llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
+				   ReturnValueSlot ReturnValue);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index da14e837f7bd0a..d405d604cf2ab5 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1617,11 +1617,13 @@ template <typename T>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 T select(bool, T, T);
   
-/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, vector<T,Sz>, FalseVals)
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         vector<T,Sz> FalseVals)
 /// \brief ternary operator for vectors. All vectors must be the same size.
 /// \param Conds The Condition input values.
 /// \param TrueVals The vector values are chosen from when conditions are true.
-/// \param FalseVals The vector values are chosen from when conditions are false.
+/// \param FalseVals The vector values are chosen from when conditions are
+/// false.
 
 template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b3b05db3b757b5..803d63e31b8321 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1512,7 +1512,24 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
   TheCall->setType(ReturnType);
 }
 
-bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
+static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
+				unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  auto *VTy = ArgType->getAs<VectorType>();
+  // not the scalar or vector<scalar>
+  if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
+	(VTy && S->Context.hasSameUnqualifiedType(VTy->getElementType(),
+						  Scalar)))) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+		   diag::err_typecheck_expect_scalar_or_vector)
+	<< ArgType << Scalar;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
   Expr *Arg2 = TheCall->getArg(2);
@@ -1529,11 +1546,11 @@ bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
-bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
+static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
   Expr *Arg2 = TheCall->getArg(2);
-  if(!Arg1->getType()->isVectorType()) {
+  if (!Arg1->getType()->isVectorType()) {
     S->Diag(Arg1->getBeginLoc(),
 	    diag::err_builtin_non_vector_type)
       << "Second" << "__builtin_hlsl_select" << Arg1->getType()
@@ -1541,7 +1558,7 @@ bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
     return true;
   }
   
-  if(!Arg2->getType()->isVectorType()) {
+  if (!Arg2->getType()->isVectorType()) {
     S->Diag(Arg2->getBeginLoc(),
 	    diag::err_builtin_non_vector_type)
       << "Third" << "__builtin_hlsl_select" << Arg2->getType()
@@ -1549,8 +1566,8 @@ bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
     return true;
   }
 
-  if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
-					Arg2->getType())) {
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(),
+					 Arg2->getType())) {
     S->Diag(TheCall->getBeginLoc(),
 	    diag::err_typecheck_call_different_arg_types)
       << Arg1->getType() << Arg2->getType()
@@ -1560,15 +1577,15 @@ bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
 
   // caller has checked that Arg0 is a vector.
   // check all three args have the same length.
-  if(TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
-     Arg1->getType()->getAs<VectorType>()->getNumElements()) {
+  if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
+      Arg1->getType()->getAs<VectorType>()->getNumElements()) {
     S->Diag(TheCall->getBeginLoc(),
 	    diag::err_typecheck_vector_lengths_not_equal)
       << TheCall->getArg(0)->getType() << Arg1->getType()
       << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
     return true;
   }
-
+  TheCall->setType(Arg1->getType());
   return false;
 }
 
@@ -1607,25 +1624,15 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_select: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
+      return true;
     QualType ArgTy = TheCall->getArg(0)->getType();
-    if (ArgTy->isBooleanType()) {
-      if (CheckBoolSelect(&SemaRef, TheCall))
-	return true;
-    } else if (ArgTy->isVectorType() &&
-	       ArgTy->getAs<VectorType>()->getElementType()->isBooleanType()) {
-      if (CheckVectorSelect(&SemaRef, TheCall))
-	return true;
-    } else { // first operand is not a bool or a vector of bools.
-      SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
-		   diag::err_typecheck_convert_incompatible)
-	<< TheCall->getArg(0)->getType() << getASTContext().BoolTy
-	<< 1 << 0 << 0;
-      SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
-		   diag::err_builtin_non_vector_type)
-	<< "First" << "__builtin_hlsl_select" << TheCall->getArg(0)->getType()
-	<< TheCall->getArg(0)->getSourceRange();
+    if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
+      return true;
+    auto *VTy = ArgTy->getAs<VectorType>();
+    if (VTy && VTy->getElementType()->isBooleanType() &&
+	CheckVectorSelect(&SemaRef, TheCall))
       return true;
-    }
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index eeafd8ea852ea3..2d5a49ab2090b9 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -4,73 +4,42 @@
 
 // CHECK: %hlsl.select = select i1
 // CHECK: ret i32 %hlsl.select
-int test_select_bool_int(bool cond0, int tVal, int fVal) { return select<int>(cond0, tVal, fVal); }
+int test_select_bool_int(bool cond0, int tVal, int fVal) {
+  return select<int>(cond0, tVal, fVal); }
 
+struct S { int a; };
 // CHECK: %hlsl.select = select i1
-// CHECK: ret <2 x i32> %hlsl.select
-vector<int,2> test_select_bool_vector(bool cond0, vector<int, 2> tVal, vector<int, 2> fVal) { return select<vector<int,2> >(cond0, tVal, fVal); }
-
-// CHECK: %4 = extractelement <1 x i1> %extractvec, i32 0
-// CHECK: %5 = extractelement <1 x i32> %2, i32 0
-// CHECK: %6 = extractelement <1 x i32> %3, i32 0
-// CHECK: %7 = select i1 %4, i32 %5, i32 %6
-// CHECK: %8 = insertelement <1 x i32> poison, i32 %7, i32 0
-// CHECK: ret <1 x i32> %8
-vector<int,1> test_select_vector_1(vector<bool,1> cond0, vector<int,1> tVals, vector<int,1> fVals) { return select<int,1>(cond0, tVals, fVals); }
-
-// CHECK: %4 = extractelement <2 x i1> %extractvec, i32 0
-// CHECK: %5 = extractelement <2 x i32> %2, i32 0
-// CHECK: %6 = extractelement <2 x i32> %3, i32 0
-// CHECK: %7 = select i1 %4, i32 %5, i32 %6
-// CHECK: %8 = insertelement <2 x i32> poison, i32 %7, i32 0
-// CHECK: %9 = extractelement <2 x i1> %extractvec, i32 1
-// CHECK: %10 = extractelement <2 x i32> %2, i32 1
-// CHECK: %11 = extractelement <2 x i32> %3, i32 1
-// CHECK: %12 = select i1 %9, i32 %10, i32 %11
-// CHECK: %13 = insertelement <2 x i32> %8, i32 %12, i32 1
-// CHECK: ret <2 x i32> %13
-vector<int,2> test_select_vector_2(vector<bool, 2> cond0, vector<int, 2> tVals, vector<int, 2> fVals) { return select<int,2>(cond0, tVals, fVals); }
-
-// CHECK: %4 = extractelement <3 x i1> %extractvec, i32 0
-// CHECK: %5 = extractelement <3 x i32> %2, i32 0
-// CHECK: %6 = extractelement <3 x i32> %3, i32 0
-// CHECK: %7 = select i1 %4, i32 %5, i32 %6
-// CHECK: %8 = insertelement <3 x i32> poison, i32 %7, i32 0
-// CHECK: %9 = extractelement <3 x i1> %extractvec, i32 1
-// CHECK: %10 = extractelement <3 x i32> %2, i32 1
-// CHECK: %11 = extractelement <3 x i32> %3, i32 1
-// CHECK: %12 = select i1 %9, i32 %10, i32 %11
-// CHECK: %13 = insertelement <3 x i32> %8, i32 %12, i32 1
-// CHECK: %14 = extractelement <3 x i1> %extractvec, i32 2
-// CHECK: %15 = extractelement <3 x i32> %2, i32 2
-// CHECK: %16 = extractelement <3 x i32> %3, i32 2
-// CHECK: %17 = select i1 %14, i32 %15, i32 %16
-// CHECK: %18 = insertelement <3 x i32> %13, i32 %17, i32 2
-// CHECK: ret <3 x i32> %18
-vector<int,3> test_select_vector_3(vector<bool, 3> cond0, vector<int, 3> tVals, vector<int, 3> fVals) { return select<int,3>(cond0, tVals, fVals); }
+// CHECK: store ptr %hlsl.select
+// CHECK: ret void
+struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) {
+  return select(cond0, tVal, fVal); }
 
-// CHECK: %4 = extractelement <4 x i1> %extractvec, i32 0
-// CHECK: %5 = extractelement <4 x i32> %2, i32 0
-// CHECK: %6 = extractelement <4 x i32> %3, i32 0
-// CHECK: %7 = select i1 %4, i32 %5, i32 %6
-// CHECK: %8 = insertelement <4 x i32> poison, i32 %7, i32 0
-// CHECK: %9 = extractelement <4 x i1> %extractvec, i32 1
-// CHECK: %10 = extractelement <4 x i32> %2, i32 1
-// CHECK: %11 = extractelement <4 x i32> %3, i32 1
-// CHECK: %12 = select i1 %9, i32 %10, i32 %11
-// CHECK: %13 = insertelement <4 x i32> %8, i32 %12, i32 1
-// CHECK: %14 = extractelement <4 x i1> %extractvec, i32 2
-// CHECK: %15 = extractelement <4 x i32> %2, i32 2
-// CHECK: %16 = extractelement <4 x i32> %3, i32 2
-// CHECK: %17 = select i1 %14, i32 %15, i32 %16
-// CHECK: %18 = insertelement <4 x i32> %13, i32 %17, i32 2
-// CHECK: %19 = extractelement <4 x i1> %extractvec, i32 3
-// CHECK: %20 = extractelement <4 x i32> %2, i32 3
-// CHECK: %21 = extractelement <4 x i32> %3, i32 3
-// CHECK: %22 = select i1 %19, i32 %20, i32 %21
-// CHECK: %23 = insertelement <4 x i32> %18, i32 %22, i32 3
-// CHECK: ret <4 x i32> %23
-vector<int,4> test_select_vector_4(vector<bool, 4> cond0, vector<int, 4> tVals, vector<int, 4> fVals) { return select<int,4>(cond0, tVals, fVals); }
+// CHECK: %hlsl.select = select i1
+// CHECK: ret <2 x i32> %hlsl.select
+vector<int,2> test_select_bool_vector(bool cond0, vector<int, 2> tVal,
+                                      vector<int, 2> fVal) {
+  return select<vector<int,2> >(cond0, tVal, fVal); }
 
+// CHECK: %hlsl.select = select <1 x i1>
+// CHECK: ret <1 x i32> %hlsl.select
+vector<int,1> test_select_vector_1(vector<bool,1> cond0, vector<int,1> tVals,
+                                   vector<int,1> fVals) {
+  return select<int,1>(cond0, tVals, fVals); }
 
+// CHECK: %hlsl.select = select <2 x i1>
+// CHECK: ret <2 x i32> %hlsl.select
+vector<int,2> test_select_vector_2(vector<bool, 2> cond0, vector<int, 2> tVals,
+                                   vector<int, 2> fVals) {
+  return select<int,2>(cond0, tVals, fVals); }
+
+// CHECK: %hlsl.select = select <3 x i1>
+// CHECK: ret <3 x i32> %hlsl.select
+vector<int,3> test_select_vector_3(vector<bool, 3> cond0, vector<int, 3> tVals,
+                                   vector<int, 3> fVals) {
+  return select<int,3>(cond0, tVals, fVals); }
+
+// CHECK: %hlsl.select = select <4 x i1>
+// CHECK: ret <4 x i32> %hlsl.select
+int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
+  return select(cond0, tVals, fVals); }
 
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index 5637c5f8176e90..68186b1b044a76 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -1,68 +1,98 @@
-// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header
+// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
+// -disable-llvm-passes -verify -verify-ignore-unexpected
 
 int test_no_arg() {
   return select();
   // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 0 were provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 0 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
+  // not viable: requires 3 arguments, but 0 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 0 were provided}}
 }
 
 int test_too_few_args(bool p0) {
   return select(p0);
   // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 1 was provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 1 was provided}}
 }
 
 int test_too_many_args(bool p0, int t0, int f0, int g0) {
   return select<int>(p0, t0, f0, g0);
   // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 4 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 4 were provided}}
 }
 
 int test_select_first_arg_wrong_type(vector<int,1> p0, int t0, int f0) {
   return select(p0, t0, f0);
   // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value) to 'bool' for 1st argument}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could not match 'vector<T, Sz>' against 'int'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
+  // to 'bool' for 1st argument}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
+  // not match 'vector<T, Sz>' against 'int'}}
 }
 
-vector<int,1> test_select_bool_vals_diff_vecs(bool p0, vector<int,1> t0, vector<int,2> f0) {
+vector<int,1> test_select_bool_vals_diff_vecs(bool p0, vector<int,1> t0,
+                                              vector<int,2> f0) {
   return select<vector<int,1> >(p0, t0, f0);
-  // expected-warning@-1 {{implicit conversion truncates vector: 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' (vector of 1 'int' value)}}
+  // expected-warning@-1 {{implicit conversion truncates vector:
+  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
+  // (vector of 1 'int' value)}}
 }
 
-vector<int,2> test_select_vector_vals_not_vecs(vector<bool,2> p0, int t0, int f0) {
+vector<int,2> test_select_vector_vals_not_vecs(vector<bool,2> p0, int t0,
+                                               int f0) {
   return select(p0, t0, f0);
   // expected-error@-1 {{no matching function for call to 'select'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could not match 'vector<T, Sz>' against 'int'}}
-  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: no known conversion from 'vector<bool, 2>' (vector of 2 'bool' values) to 'bool' for 1st argument}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
+  // could not match 'vector<T, Sz>' against 'int'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: no known conversion from 'vector<bool, 2>'
+  // (vector of 2 'bool' values) to 'bool' for 1st argument}}
 }
 
-vector<int,1> test_select_vector_vals_wrong_size(vector<bool,2> p0, vector<int,1> t0, vector<int,2> f0) {
+vector<int,1> test_select_vector_vals_wrong_size(vector<bool,2> p0,
+                                                 vector<int,1> t0,
+						 vector<int,2> f0) {
   return select<int,1>(p0, t0, f0); // produce warnings
-  // expected-warning@-1 {{implicit conversion truncates vector: 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>' (vector of 1 'bool' value)}}
-  // expected-warning@-2 {{implicit conversion truncates vector: 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' (vector of 1 'int' value)}}
+  // expected-warning@-1 {{implicit conversion truncates vector:
+  // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
+  // (vector of 1 'bool' value)}}
+  // expected-warning@-2 {{implicit conversion truncates vector:
+  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
+  // (vector of 1 'int' value)}}
 }
 
 // __builtin_hlsl_select tests
 int test_select_builtin_wrong_arg_count(bool p0, int t0) {
   return __builtin_hlsl_select(p0, t0);
-  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+  // expected-error@-1 {{too few arguments to function call, expected 3,
+  // have 2}}
 }
 
 // not a bool or a vector of bool. should be 2 errors.
 int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'bool'}}
-  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of vector type}}
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type
+  // 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
+  // vector type}}
   }
 
-int test_select_builtin_first_arg_wrong_type2(vector<int,1> p0, int t0, int f0) {
+int test_select_builtin_first_arg_wrong_type2(vector<int,1> p0, int t0,
+                                              int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to parameter of incompatible type 'bool'}}
-  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of vector type}}
+  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
+  // parameter of incompatible type 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
+  // vector type}}
 }
 
 // if a bool last 2 args are of same type
@@ -72,19 +102,28 @@ int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
 }
 
 // if a vector second arg isnt a vector
-vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0, int t0, vector<int,2> f0) {
+vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0,
+                                                        int t0,
+							vector<int,2> f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of vector type}}
+  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
+  // vector type}}
 }
 
 // if a vector third arg isn't a vector
-vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0, vector<int,2> t0, int f0) {
+vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0,
+                                                        vector<int,2> t0,
+							int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of vector type}}
+  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
+  // vector type}}
 }
 
 // if vector last 2 aren't same type (so both are vectors but wrong type)
-vector<int,2> test_select_builtin_diff_types(vector<bool,1> p0, vector<int,1> t0, vector<float,1> f0) {
+vector<int,2> test_select_builtin_diff_types(vector<bool,1> p0,
+                                             vector<int,1> t0,
+					     vector<float,1> f0) {
   return __builtin_hlsl_select(p0, t0, f0);
-  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>' vs 'vector<float, [...]>')}}
+  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
+  // vs 'vector<float, [...]>')}}
 }
\ No newline at end of file

>From dc67f0c499a1751386c10a0f4a1f850da575e652 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 5 Sep 2024 20:53:39 +0000
Subject: [PATCH 7/9] remove use of vector from hlsl test code and use
 corresponding types such as int2

---
 clang/test/CodeGenHLSL/builtins/select.hlsl   | 14 ++++------
 .../test/SemaHLSL/BuiltIns/select-errors.hlsl | 28 ++++++-------------
 2 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index 2d5a49ab2090b9..56457d82ec19e6 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -16,26 +16,22 @@ struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) {
 
 // CHECK: %hlsl.select = select i1
 // CHECK: ret <2 x i32> %hlsl.select
-vector<int,2> test_select_bool_vector(bool cond0, vector<int, 2> tVal,
-                                      vector<int, 2> fVal) {
-  return select<vector<int,2> >(cond0, tVal, fVal); }
+int2 test_select_bool_vector(bool cond0, int2 tVal, int2 fVal) {
+  return select<int2>(cond0, tVal, fVal); }
 
 // CHECK: %hlsl.select = select <1 x i1>
 // CHECK: ret <1 x i32> %hlsl.select
-vector<int,1> test_select_vector_1(vector<bool,1> cond0, vector<int,1> tVals,
-                                   vector<int,1> fVals) {
+int1 test_select_vector_2(bool1 cond0, int1 tVals, int1 fVals) {
   return select<int,1>(cond0, tVals, fVals); }
 
 // CHECK: %hlsl.select = select <2 x i1>
 // CHECK: ret <2 x i32> %hlsl.select
-vector<int,2> test_select_vector_2(vector<bool, 2> cond0, vector<int, 2> tVals,
-                                   vector<int, 2> fVals) {
+int2 test_select_vector_2(bool2 cond0, int2 tVals, int2 fVals) {
   return select<int,2>(cond0, tVals, fVals); }
 
 // CHECK: %hlsl.select = select <3 x i1>
 // CHECK: ret <3 x i32> %hlsl.select
-vector<int,3> test_select_vector_3(vector<bool, 3> cond0, vector<int, 3> tVals,
-                                   vector<int, 3> fVals) {
+int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
   return select<int,3>(cond0, tVals, fVals); }
 
 // CHECK: %hlsl.select = select <4 x i1>
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index 68186b1b044a76..eeba500a1d5297 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -29,7 +29,7 @@ int test_too_many_args(bool p0, int t0, int f0, int g0) {
   // viable: requires 3 arguments, but 4 were provided}}
 }
 
-int test_select_first_arg_wrong_type(vector<int,1> p0, int t0, int f0) {
+int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
   return select(p0, t0, f0);
   // expected-error@-1 {{no matching function for call to 'select'}}
   // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
@@ -39,15 +39,14 @@ int test_select_first_arg_wrong_type(vector<int,1> p0, int t0, int f0) {
   // not match 'vector<T, Sz>' against 'int'}}
 }
 
-vector<int,1> test_select_bool_vals_diff_vecs(bool p0, vector<int,1> t0,
-                                              vector<int,2> f0) {
-  return select<vector<int,1> >(p0, t0, f0);
+int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
+  return select<int1>(p0, t0, f0);
   // expected-warning@-1 {{implicit conversion truncates vector:
   // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
   // (vector of 1 'int' value)}}
 }
 
-vector<int,2> test_select_vector_vals_not_vecs(vector<bool,2> p0, int t0,
+int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
                                                int f0) {
   return select(p0, t0, f0);
   // expected-error@-1 {{no matching function for call to 'select'}}
@@ -58,9 +57,7 @@ vector<int,2> test_select_vector_vals_not_vecs(vector<bool,2> p0, int t0,
   // (vector of 2 'bool' values) to 'bool' for 1st argument}}
 }
 
-vector<int,1> test_select_vector_vals_wrong_size(vector<bool,2> p0,
-                                                 vector<int,1> t0,
-						 vector<int,2> f0) {
+int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
   return select<int,1>(p0, t0, f0); // produce warnings
   // expected-warning@-1 {{implicit conversion truncates vector:
   // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
@@ -86,8 +83,7 @@ int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
   // vector type}}
   }
 
-int test_select_builtin_first_arg_wrong_type2(vector<int,1> p0, int t0,
-                                              int f0) {
+int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
   // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
   // parameter of incompatible type 'bool'}}
@@ -102,27 +98,21 @@ int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
 }
 
 // if a vector second arg isnt a vector
-vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0,
-                                                        int t0,
-							vector<int,2> f0) {
+int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
   return __builtin_hlsl_select(p0, t0, f0);
   // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
   // vector type}}
 }
 
 // if a vector third arg isn't a vector
-vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0,
-                                                        vector<int,2> t0,
-							int f0) {
+int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
   return __builtin_hlsl_select(p0, t0, f0);
   // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
   // vector type}}
 }
 
 // if vector last 2 aren't same type (so both are vectors but wrong type)
-vector<int,2> test_select_builtin_diff_types(vector<bool,1> p0,
-                                             vector<int,1> t0,
-					     vector<float,1> f0) {
+int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
   return __builtin_hlsl_select(p0, t0, f0);
   // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
   // vs 'vector<float, [...]>')}}

>From 44685a3cf034a9e5d7166b93b8021724bd8fd146 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Sep 2024 15:16:32 +0000
Subject: [PATCH 8/9] address PR comments

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  2 +-
 clang/lib/Sema/SemaHLSL.cpp                   |  4 +-
 clang/test/CodeGenHLSL/builtins/select.hlsl   | 63 +++++++++++--------
 .../test/SemaHLSL/BuiltIns/select-errors.hlsl |  2 +-
 4 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index d405d604cf2ab5..1cd4f085cbec09 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1627,7 +1627,7 @@ T select(bool, T, T);
 
 template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
-vector<T,Sz> select(vector<bool,Sz>, vector<T,Sz>, vector<T,Sz>);
+vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
   
 //===----------------------------------------------------------------------===//
 // sin builtins
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 803d63e31b8321..62661aae31b879 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1553,7 +1553,7 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   if (!Arg1->getType()->isVectorType()) {
     S->Diag(Arg1->getBeginLoc(),
 	    diag::err_builtin_non_vector_type)
-      << "Second" << "__builtin_hlsl_select" << Arg1->getType()
+      << "Second" << TheCall->getDirectCallee() << Arg1->getType()
       << Arg1->getSourceRange();
     return true;
   }
@@ -1561,7 +1561,7 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   if (!Arg2->getType()->isVectorType()) {
     S->Diag(Arg2->getBeginLoc(),
 	    diag::err_builtin_non_vector_type)
-      << "Third" << "__builtin_hlsl_select" << Arg2->getType()
+      << "Third" << TheCall->getDirectCallee() << Arg2->getType()
       << Arg2->getSourceRange();
     return true;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index 56457d82ec19e6..cade938b71a2ba 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -2,40 +2,53 @@
 // RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
 // RUN:   -o - | FileCheck %s --check-prefixes=CHECK
 
-// CHECK: %hlsl.select = select i1
-// CHECK: ret i32 %hlsl.select
+// CHECK-LABEL: test_select_bool_int
+// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, i32 {{%.*}}, i32 {{%.*}}
+// CHECK: ret i32 [[SELECT]]
 int test_select_bool_int(bool cond0, int tVal, int fVal) {
-  return select<int>(cond0, tVal, fVal); }
+  return select<int>(cond0, tVal, fVal);
+}
 
 struct S { int a; };
-// CHECK: %hlsl.select = select i1
-// CHECK: store ptr %hlsl.select
+// CHECK-LABEL: test_select_infer
+// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, ptr {{%.*}}, ptr {{%.*}}
+// CHECK: store ptr [[SELECT]]
 // CHECK: ret void
 struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) {
-  return select(cond0, tVal, fVal); }
+  return select(cond0, tVal, fVal);
+}
 
-// CHECK: %hlsl.select = select i1
-// CHECK: ret <2 x i32> %hlsl.select
+// CHECK-LABEL: test_select_bool_vector
+// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
+// CHECK: ret <2 x i32> [[SELECT]]
 int2 test_select_bool_vector(bool cond0, int2 tVal, int2 fVal) {
-  return select<int2>(cond0, tVal, fVal); }
-
-// CHECK: %hlsl.select = select <1 x i1>
-// CHECK: ret <1 x i32> %hlsl.select
-int1 test_select_vector_2(bool1 cond0, int1 tVals, int1 fVals) {
-  return select<int,1>(cond0, tVals, fVals); }
-
-// CHECK: %hlsl.select = select <2 x i1>
-// CHECK: ret <2 x i32> %hlsl.select
+  return select<int2>(cond0, tVal, fVal);
+}
+
+// CHECK-LABEL: test_select_vector_1
+// CHECK: [[SELECT:%.*]] = select <1 x i1> {{%.*}}, <1 x i32> {{%.*}}, <1 x i32> {{%.*}}
+// CHECK: ret <1 x i32> [[SELECT]]
+int1 test_select_vector_1(bool1 cond0, int1 tVals, int1 fVals) {
+  return select<int,1>(cond0, tVals, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_2
+// CHECK: [[SELECT:%.*]] = select <2 x i1> {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
+// CHECK: ret <2 x i32> [[SELECT]]
 int2 test_select_vector_2(bool2 cond0, int2 tVals, int2 fVals) {
-  return select<int,2>(cond0, tVals, fVals); }
+  return select<int,2>(cond0, tVals, fVals);
+}
 
-// CHECK: %hlsl.select = select <3 x i1>
-// CHECK: ret <3 x i32> %hlsl.select
+// CHECK-LABEL: test_select_vector_3
+// CHECK: [[SELECT:%.*]] = select <3 x i1> {{%.*}}, <3 x i32> {{%.*}}, <3 x i32> {{%.*}}
+// CHECK: ret <3 x i32> [[SELECT]]
 int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
-  return select<int,3>(cond0, tVals, fVals); }
+  return select<int,3>(cond0, tVals, fVals);
+}
 
-// CHECK: %hlsl.select = select <4 x i1>
-// CHECK: ret <4 x i32> %hlsl.select
+// CHECK-LABEL: test_select_vector_4
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> {{%.*}}
+// CHECK: ret <4 x i32> [[SELECT]]
 int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
-  return select(cond0, tVals, fVals); }
-
+  return select(cond0, tVals, fVals);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index eeba500a1d5297..34b5fb6d54cd57 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -116,4 +116,4 @@ int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
   return __builtin_hlsl_select(p0, t0, f0);
   // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
   // vs 'vector<float, [...]>')}}
-}
\ No newline at end of file
+}

>From bdafa1e0387a19e365654fda214d9029ac7ed4f7 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Sat, 7 Sep 2024 01:45:40 +0000
Subject: [PATCH 9/9] make clang-format happy

---
 clang/lib/CodeGen/CGBuiltin.cpp          | 22 ++++++----
 clang/lib/CodeGen/CodeGenFunction.h      |  2 +-
 clang/lib/Headers/hlsl/hlsl_intrinsics.h |  4 +-
 clang/lib/Sema/SemaHLSL.cpp              | 52 +++++++++++-------------
 4 files changed, 40 insertions(+), 40 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 97e56b411bff46..79f09bfd0f13d1 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18521,7 +18521,7 @@ Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
-					    ReturnValueSlot ReturnValue) {
+                                            ReturnValueSlot ReturnValue) {
   if (!getLangOpts().HLSL)
     return nullptr;
 
@@ -18711,17 +18711,21 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
   case Builtin::BI__builtin_hlsl_select: {
     Value *OpCond = EmitScalarExpr(E->getArg(0));
     RValue RValTrue = EmitAnyExpr(E->getArg(1));
-    Value *OpTrue = RValTrue.isScalar() ? RValTrue.getScalarVal()
-      : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
+    Value *OpTrue =
+        RValTrue.isScalar()
+            ? RValTrue.getScalarVal()
+            : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
     RValue RValFalse = EmitAnyExpr(E->getArg(2));
-    Value *OpFalse = RValFalse.isScalar() ? RValFalse.getScalarVal()
-      : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
-    
-    Value *SelectVal = Builder.CreateSelect(OpCond, OpTrue, OpFalse,
-					    "hlsl.select");
+    Value *OpFalse =
+        RValFalse.isScalar()
+            ? RValFalse.getScalarVal()
+            : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
+
+    Value *SelectVal =
+        Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
     if (!RValTrue.isScalar())
       Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
-			  ReturnValue.isVolatile());
+                          ReturnValue.isVolatile());
 
     return SelectVal;
   }
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 81a0b666183cc7..35a5275e61211d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4701,7 +4701,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
-				   ReturnValueSlot ReturnValue);
+                                   ReturnValueSlot ReturnValue);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 1cd4f085cbec09..f7e37511cbe4ef 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1616,7 +1616,7 @@ double4 saturate(double4);
 template <typename T>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 T select(bool, T, T);
-  
+
 /// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
 ///                         vector<T,Sz> FalseVals)
 /// \brief ternary operator for vectors. All vectors must be the same size.
@@ -1628,7 +1628,7 @@ T select(bool, T, T);
 template <typename T, int Sz>
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
 vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
-  
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 62661aae31b879..49482fdfd0b1e5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1513,17 +1513,17 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
 }
 
 static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
-				unsigned ArgIndex) {
+                                unsigned ArgIndex) {
   assert(TheCall->getNumArgs() >= ArgIndex);
   QualType ArgType = TheCall->getArg(ArgIndex)->getType();
   auto *VTy = ArgType->getAs<VectorType>();
   // not the scalar or vector<scalar>
   if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
-	(VTy && S->Context.hasSameUnqualifiedType(VTy->getElementType(),
-						  Scalar)))) {
+        (VTy &&
+         S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
     S->Diag(TheCall->getArg(0)->getBeginLoc(),
-		   diag::err_typecheck_expect_scalar_or_vector)
-	<< ArgType << Scalar;
+            diag::err_typecheck_expect_scalar_or_vector)
+        << ArgType << Scalar;
     return true;
   }
   return false;
@@ -1533,12 +1533,11 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() == 3);
   Expr *Arg1 = TheCall->getArg(1);
   Expr *Arg2 = TheCall->getArg(2);
-  if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
-					Arg2->getType())) {
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
     S->Diag(TheCall->getBeginLoc(),
-	    diag::err_typecheck_call_different_arg_types)
-      << Arg1->getType() << Arg2->getType()
-      << Arg1->getSourceRange() << Arg2->getSourceRange();
+            diag::err_typecheck_call_different_arg_types)
+        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+        << Arg2->getSourceRange();
     return true;
   }
 
@@ -1551,27 +1550,24 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   Expr *Arg1 = TheCall->getArg(1);
   Expr *Arg2 = TheCall->getArg(2);
   if (!Arg1->getType()->isVectorType()) {
-    S->Diag(Arg1->getBeginLoc(),
-	    diag::err_builtin_non_vector_type)
-      << "Second" << TheCall->getDirectCallee() << Arg1->getType()
-      << Arg1->getSourceRange();
+    S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
+        << "Second" << TheCall->getDirectCallee() << Arg1->getType()
+        << Arg1->getSourceRange();
     return true;
   }
-  
+
   if (!Arg2->getType()->isVectorType()) {
-    S->Diag(Arg2->getBeginLoc(),
-	    diag::err_builtin_non_vector_type)
-      << "Third" << TheCall->getDirectCallee() << Arg2->getType()
-      << Arg2->getSourceRange();
+    S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
+        << "Third" << TheCall->getDirectCallee() << Arg2->getType()
+        << Arg2->getSourceRange();
     return true;
   }
 
-  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(),
-					 Arg2->getType())) {
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
     S->Diag(TheCall->getBeginLoc(),
-	    diag::err_typecheck_call_different_arg_types)
-      << Arg1->getType() << Arg2->getType()
-      << Arg1->getSourceRange() << Arg2->getSourceRange();
+            diag::err_typecheck_call_different_arg_types)
+        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+        << Arg2->getSourceRange();
     return true;
   }
 
@@ -1580,9 +1576,9 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
       Arg1->getType()->getAs<VectorType>()->getNumElements()) {
     S->Diag(TheCall->getBeginLoc(),
-	    diag::err_typecheck_vector_lengths_not_equal)
-      << TheCall->getArg(0)->getType() << Arg1->getType()
-      << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
+            diag::err_typecheck_vector_lengths_not_equal)
+        << TheCall->getArg(0)->getType() << Arg1->getType()
+        << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
     return true;
   }
   TheCall->setType(Arg1->getType());
@@ -1631,7 +1627,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     auto *VTy = ArgTy->getAs<VectorType>();
     if (VTy && VTy->getElementType()->isBooleanType() &&
-	CheckVectorSelect(&SemaRef, TheCall))
+        CheckVectorSelect(&SemaRef, TheCall))
       return true;
     break;
   }



More information about the cfe-commits mailing list