[clang] [HLSL] Implement support for HLSL intrinsic - select (PR #107129)

via cfe-commits cfe-commits at lists.llvm.org
Tue Sep 3 10:03:01 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang-codegen

Author: Sarah Spall (spall)

<details>
<summary>Changes</summary>

Implement support for HLSL intrinsic select.
This would close issue #75377.

---
Full diff: https://github.com/llvm/llvm-project/pull/107129.diff


6 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+6) 
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+44) 
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+24) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+84) 
- (added) clang/test/CodeGenHLSL/builtins/select.hlsl (+76) 
- (added) clang/test/SemaHLSL/BuiltIns/select-errors.hlsl (+90) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index ac33672a32b336..291f417a7243d6 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
+  let Prototype = "void(...)";
+  let Spellings = ["__builtin_hlsl_select"];
+  let Attributes = [NoThrow, Const];
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 2a733e4d834cfa..d625a631c4b6a4 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18695,6 +18695,50 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.saturate");
   }
+  case Builtin::BI__builtin_hlsl_select: {
+    Value *OpCond = EmitScalarExpr(E->getArg(0));
+    Value *OpTrue = EmitScalarExpr(E->getArg(1));
+    Value *OpFalse = EmitScalarExpr(E->getArg(2));
+    llvm::Type *TCond = OpCond->getType();
+
+    // A scalar bool condition selects the whole TrueVal or FalseVal, which may
+    // themselves be scalars or vectors.
+    if (TCond->isIntegerTy(1))
+      return Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
+
+    // A vector-of-bool condition selects element-wise: for each lane, pick the
+    // matching lane of TrueVal or FalseVal. SemaHLSL has already checked that
+    // TrueVal and FalseVal are vectors of one type with the condition's length.
+    QualType CondTy = E->getArg(0)->getType();
+    if (TCond->isVectorTy() &&
+        CondTy->castAs<VectorType>()->getElementType()->isBooleanType()) {
+      llvm::Type *ResultTy = OpTrue->getType();
+      assert(ResultTy->isVectorTy() && OpFalse->getType() == ResultTy &&
+             "Select's second and third operands must be vectors of the same "
+             "type if the first operand is a vector.");
+
+      const unsigned N =
+          cast<llvm::FixedVectorType>(ResultTy)->getNumElements();
+      assert(N == cast<llvm::FixedVectorType>(TCond)->getNumElements() &&
+             "Select requires vectors to be of the same size.");
+
+      // Build the result in TrueVal's vector type. Hard-coding an int vector
+      // here would produce ill-typed IR for float, half or double selects.
+      llvm::Value *Result = llvm::PoisonValue::get(ResultTy);
+      for (unsigned I = 0; I < N; ++I) {
+        Value *Index = ConstantInt::get(IntTy, I);
+        Value *LaneCond = Builder.CreateExtractElement(OpCond, Index);
+        Value *TVal = Builder.CreateExtractElement(OpTrue, Index);
+        Value *FVal = Builder.CreateExtractElement(OpFalse, Index);
+        Value *LaneSel = Builder.CreateSelect(LaneCond, TVal, FVal);
+        Result = Builder.CreateInsertElement(Result, LaneSel, Index);
+      }
+      return Result;
+    }
+
+    llvm_unreachable(
+        "Select requires a bool or vector of bools as its first operand.");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6d38b668fe770e..da14e837f7bd0a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1603,6 +1603,30 @@ double3 saturate(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
 double4 saturate(double4);
 
+//===----------------------------------------------------------------------===//
+// select builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T select(bool Cond, T TrueVal, T FalseVal)
+/// \brief Returns TrueVal if Cond is true, otherwise FalseVal (like ?:).
+/// \param Cond The condition to test.
+/// \param TrueVal The value returned if Cond is true.
+/// \param FalseVal The value returned if Cond is false.
+
+template <typename T>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+T select(bool, T, T);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, vector<T,Sz> FalseVals)
+/// \brief Element-wise select: Conds[i] ? TrueVals[i] : FalseVals[i].
+/// \param Conds The per-element conditions.
+/// \param TrueVals The values chosen where the condition is true.
+/// \param FalseVals The values chosen where the condition is false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T,Sz> select(vector<bool,Sz>, vector<T,Sz>, vector<T,Sz>);
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index df01549cc2eeb6..8190ed14f8dd49 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1013,6 +1013,66 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
   TheCall->setType(ReturnType);
 }
 
+// select(bool, T, T): both values must share a type, which the call returns.
+static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() == 3);
+  Expr *Arg1 = TheCall->getArg(1);
+  Expr *Arg2 = TheCall->getArg(2);
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_call_different_arg_types)
+        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+        << Arg2->getSourceRange();
+    return true;
+  }
+
+  TheCall->setType(Arg1->getType());
+  return false;
+}
+
+// Cond is a vector<bool, N> (checked by the caller): both values must be the
+// same vector type V with N elements, and the call returns V.
+static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() == 3);
+  Expr *Arg0 = TheCall->getArg(0);
+  Expr *Arg1 = TheCall->getArg(1);
+  Expr *Arg2 = TheCall->getArg(2);
+  if (!Arg1->getType()->isVectorType()) {
+    S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
+        << "Second" << "__builtin_hlsl_select" << Arg1->getType()
+        << Arg1->getSourceRange();
+    return true;
+  }
+
+  if (!Arg2->getType()->isVectorType()) {
+    S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
+        << "Third" << "__builtin_hlsl_select" << Arg2->getType()
+        << Arg2->getSourceRange();
+    return true;
+  }
+
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_call_different_arg_types)
+        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+        << Arg2->getSourceRange();
+    return true;
+  }
+
+  // Arg1 and Arg2 share a type, so checking Arg0 against Arg1 covers all three.
+  if (Arg0->getType()->castAs<VectorType>()->getNumElements() !=
+      Arg1->getType()->castAs<VectorType>()->getNumElements()) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_vector_lengths_not_equal)
+        << Arg0->getType() << Arg1->getType() << Arg0->getSourceRange()
+        << Arg1->getSourceRange();
+    return true;
+  }
+
+  TheCall->setType(Arg1->getType());
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -1046,6 +1106,30 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
+  case Builtin::BI__builtin_hlsl_select: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+    QualType ArgTy = TheCall->getArg(0)->getType();
+    if (ArgTy->isBooleanType()) {
+      if (CheckBoolSelect(&SemaRef, TheCall))
+	return true;
+    } else if (ArgTy->isVectorType() &&
+	       ArgTy->getAs<VectorType>()->getElementType()->isBooleanType()) {
+      if (CheckVectorSelect(&SemaRef, TheCall))
+	return true;
+    } else { // first operand is not a bool or a vector of bools.
+      SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+		   diag::err_typecheck_convert_incompatible)
+	<< TheCall->getArg(0)->getType() << getASTContext().BoolTy
+	<< 1 << 0 << 0;
+      SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+		   diag::err_builtin_non_vector_type)
+	<< "First" << "__builtin_hlsl_select" << TheCall->getArg(0)->getType()
+	<< TheCall->getArg(0)->getSourceRange();
+      return true;
+    }
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
new file mode 100644
index 00000000000000..eeafd8ea852ea3
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK
+
+// Scalar bool condition: a single select over the whole value.
+// CHECK: %hlsl.select = select i1
+// CHECK: ret i32 %hlsl.select
+int test_select_bool_int(bool cond0, int tVal, int fVal) { return select<int>(cond0, tVal, fVal); }
+
+// CHECK: %hlsl.select = select i1
+// CHECK: ret <2 x i32> %hlsl.select
+vector<int,2> test_select_bool_vector(bool cond0, vector<int, 2> tVal, vector<int, 2> fVal) { return select<vector<int,2> >(cond0, tVal, fVal); }
+
+// Vector-of-bool condition: one extractelement/select/insertelement per lane.
+// These directives match exact unnamed value numbers (%4, %5, ...).
+// CHECK: %4 = extractelement <1 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <1 x i32> %2, i32 0
+// CHECK: %6 = extractelement <1 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <1 x i32> poison, i32 %7, i32 0
+// CHECK: ret <1 x i32> %8
+vector<int,1> test_select_vector_1(vector<bool,1> cond0, vector<int,1> tVals, vector<int,1> fVals) { return select<int,1>(cond0, tVals, fVals); }
+
+// CHECK: %4 = extractelement <2 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <2 x i32> %2, i32 0
+// CHECK: %6 = extractelement <2 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <2 x i32> poison, i32 %7, i32 0
+// CHECK: %9 = extractelement <2 x i1> %extractvec, i32 1
+// CHECK: %10 = extractelement <2 x i32> %2, i32 1
+// CHECK: %11 = extractelement <2 x i32> %3, i32 1
+// CHECK: %12 = select i1 %9, i32 %10, i32 %11
+// CHECK: %13 = insertelement <2 x i32> %8, i32 %12, i32 1
+// CHECK: ret <2 x i32> %13
+vector<int,2> test_select_vector_2(vector<bool, 2> cond0, vector<int, 2> tVals, vector<int, 2> fVals) { return select<int,2>(cond0, tVals, fVals); }
+
+// CHECK: %4 = extractelement <3 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <3 x i32> %2, i32 0
+// CHECK: %6 = extractelement <3 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <3 x i32> poison, i32 %7, i32 0
+// CHECK: %9 = extractelement <3 x i1> %extractvec, i32 1
+// CHECK: %10 = extractelement <3 x i32> %2, i32 1
+// CHECK: %11 = extractelement <3 x i32> %3, i32 1
+// CHECK: %12 = select i1 %9, i32 %10, i32 %11
+// CHECK: %13 = insertelement <3 x i32> %8, i32 %12, i32 1
+// CHECK: %14 = extractelement <3 x i1> %extractvec, i32 2
+// CHECK: %15 = extractelement <3 x i32> %2, i32 2
+// CHECK: %16 = extractelement <3 x i32> %3, i32 2
+// CHECK: %17 = select i1 %14, i32 %15, i32 %16
+// CHECK: %18 = insertelement <3 x i32> %13, i32 %17, i32 2
+// CHECK: ret <3 x i32> %18
+vector<int,3> test_select_vector_3(vector<bool, 3> cond0, vector<int, 3> tVals, vector<int, 3> fVals) { return select<int,3>(cond0, tVals, fVals); }
+
+// CHECK: %4 = extractelement <4 x i1> %extractvec, i32 0
+// CHECK: %5 = extractelement <4 x i32> %2, i32 0
+// CHECK: %6 = extractelement <4 x i32> %3, i32 0
+// CHECK: %7 = select i1 %4, i32 %5, i32 %6
+// CHECK: %8 = insertelement <4 x i32> poison, i32 %7, i32 0
+// CHECK: %9 = extractelement <4 x i1> %extractvec, i32 1
+// CHECK: %10 = extractelement <4 x i32> %2, i32 1
+// CHECK: %11 = extractelement <4 x i32> %3, i32 1
+// CHECK: %12 = select i1 %9, i32 %10, i32 %11
+// CHECK: %13 = insertelement <4 x i32> %8, i32 %12, i32 1
+// CHECK: %14 = extractelement <4 x i1> %extractvec, i32 2
+// CHECK: %15 = extractelement <4 x i32> %2, i32 2
+// CHECK: %16 = extractelement <4 x i32> %3, i32 2
+// CHECK: %17 = select i1 %14, i32 %15, i32 %16
+// CHECK: %18 = insertelement <4 x i32> %13, i32 %17, i32 2
+// CHECK: %19 = extractelement <4 x i1> %extractvec, i32 3
+// CHECK: %20 = extractelement <4 x i32> %2, i32 3
+// CHECK: %21 = extractelement <4 x i32> %3, i32 3
+// CHECK: %22 = select i1 %19, i32 %20, i32 %21
+// CHECK: %23 = insertelement <4 x i32> %18, i32 %22, i32 3
+// CHECK: ret <4 x i32> %23
+vector<int,4> test_select_vector_4(vector<bool, 4> cond0, vector<int, 4> tVals, vector<int, 4> fVals) { return select<int,4>(cond0, tVals, fVals); }
diff --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
new file mode 100644
index 00000000000000..5637c5f8176e90
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
+
+int test_no_arg() {
+  return select();
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 0 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 0 were provided}}
+}
+
+int test_too_few_args(bool p0) {
+  return select(p0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+  return select<int>(p0, t0, f0, g0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+}
+
+int test_select_first_arg_wrong_type(vector<int,1> p0, int t0, int f0) {
+  return select(p0, t0, f0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value) to 'bool' for 1st argument}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could not match 'vector<T, Sz>' against 'int'}}
+}
+
+vector<int,1> test_select_bool_vals_diff_vecs(bool p0, vector<int,1> t0, vector<int,2> f0) {
+  return select<vector<int,1> >(p0, t0, f0);
+  // expected-warning@-1 {{implicit conversion truncates vector: 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' (vector of 1 'int' value)}}
+}
+
+vector<int,2> test_select_vector_vals_not_vecs(vector<bool,2> p0, int t0, int f0) {
+  return select(p0, t0, f0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could not match 'vector<T, Sz>' against 'int'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: no known conversion from 'vector<bool, 2>' (vector of 2 'bool' values) to 'bool' for 1st argument}}
+}
+
+vector<int,1> test_select_vector_vals_wrong_size(vector<bool,2> p0, vector<int,1> t0, vector<int,2> f0) {
+  return select<int,1>(p0, t0, f0); // produce warnings
+  // expected-warning@-1 {{implicit conversion truncates vector: 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>' (vector of 1 'bool' value)}}
+  // expected-warning@-2 {{implicit conversion truncates vector: 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>' (vector of 1 'int' value)}}
+}
+
+// __builtin_hlsl_select tests
+int test_select_builtin_wrong_arg_count(bool p0, int t0) {
+  return __builtin_hlsl_select(p0, t0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+// not a bool or a vector of bool. should be 2 errors.
+int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of vector type}}
+}
+
+int test_select_builtin_first_arg_wrong_type2(vector<int,1> p0, int t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to parameter of incompatible type 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of vector type}}
+}
+
+// if Cond is a bool, the last 2 args must have the same type
+int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
+}
+
+// if Cond is a vector, the second arg must be a vector
+vector<int,2> test_select_builtin_second_arg_not_vector(vector<bool,2> p0, int t0, vector<int,2> f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of vector type}}
+}
+
+// if Cond is a vector, the third arg must be a vector
+vector<int,2> test_select_builtin_third_arg_not_vector(vector<bool,2> p0, vector<int,2> t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of vector type}}
+}
+
+// if vector last 2 aren't same type (so both are vectors but wrong type)
+vector<int,2> test_select_builtin_diff_types(vector<bool,1> p0, vector<int,1> t0, vector<float,1> f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>' vs 'vector<float, [...]>')}}
+}
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/107129


More information about the cfe-commits mailing list