[clang] [HLSL] select scalar overloads for vector conditions (PR #129396)
via cfe-commits
cfe-commits at lists.llvm.org
Sat Mar 1 11:25:19 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Chris B (llvm-beanz)
<details>
<summary>Changes</summary>
This PR adds overloads of the `select` intrinsic that accept a scalar second and/or third argument together with a vector condition. It also updates Sema checking and CodeGen so that those scalar arguments are splatted to the vector width.
Fixes #126570
---
Full diff: https://github.com/llvm/llvm-project/pull/129396.diff
7 Files Affected:
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
- (modified) clang/lib/Headers/hlsl/hlsl_detail.h (+5)
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+36)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+32-24)
- (modified) clang/test/CodeGenHLSL/builtins/select.hlsl (+29)
- (modified) clang/test/SemaHLSL/BuiltIns/select-errors.hlsl (+22-76)
``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index d094c075ecee2..be649f0bce320 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12682,6 +12682,10 @@ def err_hlsl_param_qualifier_mismatch :
def err_hlsl_vector_compound_assignment_truncation : Error<
"left hand operand of type %0 to compound assignment cannot be truncated "
"when used with right hand operand of type %1">;
+// %0: 0 = all arguments, 1 = second and third; %1: callee; %2/%3: arg types.
+def err_hlsl_builtin_scalar_vector_mismatch : Error<
+ "%select{all|second and third}0 arguments to %1 must be of scalar or "
+ "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
def warn_hlsl_impcast_vector_truncation : Warning<
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 03b8d16b76e0d..a84e5e4b59c89 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19741,6 +19741,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
RValFalse.isScalar()
? RValFalse.getScalarVal()
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
+ // When the result is a vector, a value operand may still be a scalar; splat
+ // it to the result width so both select operands match the result type.
+ if (auto *VTy = E->getType()->getAs<VectorType>()) {
+ if (!OpTrue->getType()->isVectorTy())
+ OpTrue =
+ Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
+ if (!OpFalse->getType()->isVectorTy())
+ OpFalse =
+ Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
+ }
Value *SelectVal =
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 0d568539cd66a..daccd2d793aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -95,6 +95,13 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}
+// Compile-time trait: Value is true when T is an arithmetic type, as reported
+// by the __is_arithmetic builtin (suitable as an enable_if_t condition).
+template<typename T>
+struct is_arithmetic {
+ static const bool Value = __is_arithmetic(T);
+};
+
} // namespace __detail
} // namespace hlsl
#endif //_HLSL_HLSL_DETAILS_H_
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ed008eeb04ba8..77a7f773b85b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2246,6 +2246,42 @@ template <typename T, int Sz>
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
+/// vector<T,Sz> FalseVals)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVals The vector values are chosen from when conditions are
+/// false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+/// T FalseVal)
+/// \brief ternary operator for vectors. All vectors must be the same size.
+/// \param Conds The Condition input values.
+/// \param TrueVals The vector values are chosen from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
+/// T FalseVal)
+/// \brief ternary operator for vectors. Both scalars are splatted to the size of Conds.
+/// \param Conds The Condition input values.
+/// \param TrueVal The scalar value to splat from when conditions are true.
+/// \param FalseVal The scalar value to splat from when conditions are false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
+ vector<bool, Sz>, T, T);
+
//===----------------------------------------------------------------------===//
// sin builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bfe84b16218b7..4ec31cd39eb60 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2213,40 +2213,53 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
assert(TheCall->getNumArgs() == 3);
Expr *Arg1 = TheCall->getArg(1);
+ QualType Arg1Ty = Arg1->getType();
Expr *Arg2 = TheCall->getArg(2);
- if (!Arg1->getType()->isVectorType()) {
- S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
- << "Second" << TheCall->getDirectCallee() << Arg1->getType()
+ QualType Arg2Ty = Arg2->getType();
+
+ // The second and third arguments may each be a scalar or a vector; only
+ // their scalar element types are required to match.
+ QualType Arg1ScalarTy = Arg1Ty;
+ if (const auto *VTy = Arg1ScalarTy->getAs<VectorType>())
+ Arg1ScalarTy = VTy->getElementType();
+
+ QualType Arg2ScalarTy = Arg2Ty;
+ if (const auto *VTy = Arg2ScalarTy->getAs<VectorType>())
+ Arg2ScalarTy = VTy->getElementType();
+
+ if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy)) {
+ S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
+ << /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
+ return true;
+ }
+
+ // The caller has already checked that the condition (Arg0) is a vector.
+ QualType Arg0Ty = TheCall->getArg(0)->getType();
+ unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
+ unsigned Arg1Length = Arg1Ty->isVectorType()
+ ? Arg1Ty->getAs<VectorType>()->getNumElements()
+ : 0;
+ unsigned Arg2Length = Arg2Ty->isVectorType()
+ ? Arg2Ty->getAs<VectorType>()->getNumElements()
+ : 0;
+ if (Arg1Length > 0 && Arg0Length != Arg1Length) {
+ S->Diag(TheCall->getBeginLoc(),
+ diag::err_typecheck_vector_lengths_not_equal)
+ << Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
<< Arg1->getSourceRange();
return true;
}
- if (!Arg2->getType()->isVectorType()) {
- S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
- << "Third" << TheCall->getDirectCallee() << Arg2->getType()
- << Arg2->getSourceRange();
- return true;
- }
-
- if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+ if (Arg2Length > 0 && Arg0Length != Arg2Length) {
S->Diag(TheCall->getBeginLoc(),
- diag::err_typecheck_call_different_arg_types)
- << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+ diag::err_typecheck_vector_lengths_not_equal)
+ << Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
<< Arg2->getSourceRange();
return true;
}
- // caller has checked that Arg0 is a vector.
- // check all three args have the same length.
- if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
- Arg1->getType()->getAs<VectorType>()->getNumElements()) {
- S->Diag(TheCall->getBeginLoc(),
- diag::err_typecheck_vector_lengths_not_equal)
- << TheCall->getArg(0)->getType() << Arg1->getType()
- << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
- return true;
- }
- TheCall->setType(Arg1->getType());
+ TheCall->setType(
+ S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
return false;
}
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
index cade938b71a2b..196b8a90cd877 100644
--- a/clang/test/CodeGenHLSL/builtins/select.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
return select(cond0, tVals, fVals);
}
+
+// CHECK-LABEL: test_select_vector_scalar_vector
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
+ return select(cond0, tVal, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_vector_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
+ return select(cond0, tVals, fVal);
+}
+
+// CHECK-LABEL: test_select_vector_scalar_scalar
+// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
+// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
+ return select(cond0, tVal, fVal);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
index 34b5fb6d54cd5..b445cedcba074 100644
--- a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -1,119 +1,65 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-int test_no_arg() {
- return select();
- // expected-error@-1 {{no matching function for call to 'select'}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
- // not viable: requires 3 arguments, but 0 were provided}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: requires 3 arguments, but 0 were provided}}
-}
-
-int test_too_few_args(bool p0) {
- return select(p0);
- // expected-error@-1 {{no matching function for call to 'select'}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: requires 3 arguments, but 1 was provided}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: requires 3 arguments, but 1 was provided}}
-}
-
-int test_too_many_args(bool p0, int t0, int f0, int g0) {
- return select<int>(p0, t0, f0, g0);
- // expected-error@-1 {{no matching function for call to 'select'}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: requires 3 arguments, but 4 were provided}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: requires 3 arguments, but 4 were provided}}
-}
int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
return select(p0, t0, f0);
- // expected-error@-1 {{no matching function for call to 'select'}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
- // to 'bool' for 1st argument}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
- // not match 'vector<T, Sz>' against 'int'}}
}
int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
return select<int1>(p0, t0, f0);
- // expected-warning@-1 {{implicit conversion truncates vector:
- // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
- // (vector of 1 'int' value)}}
}
int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
int f0) {
return select(p0, t0, f0);
- // expected-error@-1 {{no matching function for call to 'select'}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
- // could not match 'vector<T, Sz>' against 'int'}}
- // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
- // viable: no known conversion from 'vector<bool, 2>'
- // (vector of 2 'bool' values) to 'bool' for 1st argument}}
}
int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
- return select<int,1>(p0, t0, f0); // produce warnings
- // expected-warning@-1 {{implicit conversion truncates vector:
- // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
- // (vector of 1 'bool' value)}}
- // expected-warning@-2 {{implicit conversion truncates vector:
- // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
- // (vector of 1 'int' value)}}
+ return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}}
+}
+
+int test_select_no_args() {
+ return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}}
+}
+
+int test_select_builtin_wrong_arg_count(bool p0) {
+ return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}}
}
// __builtin_hlsl_select tests
-int test_select_builtin_wrong_arg_count(bool p0, int t0) {
- return __builtin_hlsl_select(p0, t0);
- // expected-error@-1 {{too few arguments to function call, expected 3,
- // have 2}}
+int test_select_builtin_wrong_arg_count2(bool p0, int t0) {
+ return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+ return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}}
}
-// not a bool or a vector of bool. should be 2 errors.
+// not a bool or a vector of bool. should be 1 error.
int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
- return __builtin_hlsl_select(p0, t0, f0);
- // expected-error@-1 {{passing 'int' to parameter of incompatible type
- // 'bool'}}
- // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
- // vector type}}
- }
+ return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}}
+}
int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
- return __builtin_hlsl_select(p0, t0, f0);
- // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
- // parameter of incompatible type 'bool'}}
- // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
- // vector type}}
+ return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}}
}
// if a bool last 2 args are of same type
int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
- return __builtin_hlsl_select(p0, t0, f0);
- // expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
+ return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
}
-// if a vector second arg isnt a vector
+// if the condition is a vector, a scalar second arg is splatted (no error)
int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
return __builtin_hlsl_select(p0, t0, f0);
- // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
- // vector type}}
}
-// if a vector third arg isn't a vector
-int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
+// if the condition is a vector, a scalar third arg is splatted (no error)
+int2 test_select_builtin_third_arg_not_vector(bool2 p0, int2 t0, int f0) {
return __builtin_hlsl_select(p0, t0, f0);
- // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
- // vector type}}
}
// if vector last 2 aren't same type (so both are vectors but wrong type)
-int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
- return __builtin_hlsl_select(p0, t0, f0);
- // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
- // vs 'vector<float, [...]>')}}
+int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
+ return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/129396
More information about the cfe-commits mailing list