[clang] 0f349b7 - [HLSL] Implement support for HLSL intrinsic - select (#107129)

via cfe-commits cfe-commits at lists.llvm.org
Mon Sep 9 11:07:24 PDT 2024


Author: Sarah Spall
Date: 2024-09-09T11:07:20-07:00
New Revision: 0f349b7a9cde0080e626f6cfd362885341eb63b4

URL: https://github.com/llvm/llvm-project/commit/0f349b7a9cde0080e626f6cfd362885341eb63b4
DIFF: https://github.com/llvm/llvm-project/commit/0f349b7a9cde0080e626f6cfd362885341eb63b4.diff

LOG: [HLSL] Implement support for HLSL intrinsic - select (#107129)

Implement support for HLSL intrinsic select.
This would close issue #75377

Added: 
    clang/test/CodeGenHLSL/builtins/select.hlsl
    clang/test/SemaHLSL/BuiltIns/select-errors.hlsl

Modified: 
    clang/include/clang/Basic/Builtins.td
    clang/include/clang/Basic/DiagnosticSemaKinds.td
    clang/lib/CodeGen/CGBuiltin.cpp
    clang/lib/CodeGen/CodeGenFunction.h
    clang/lib/Headers/hlsl/hlsl_intrinsics.h
    clang/lib/Sema/SemaHLSL.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 92118418d9d459..d9833b6559eab3 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4763,6 +4763,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_select"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)"; // arg/result types are checked in SemaHLSL
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];

diff  --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 58819a64813fce..b160fee827a750 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -9206,6 +9206,9 @@ def err_typecheck_expect_scalar_operand : Error<
   "operand of type %0 where arithmetic or pointer type is required">;
 def err_typecheck_cond_incompatible_operands : Error<
   "incompatible operand types%diff{ ($ and $)|}0,1">;
+def err_typecheck_expect_scalar_or_vector : Error<
+  "invalid operand of type %0 where %1 or "
+  "a vector of such type is required">; // %0: operand type, %1: scalar type
 def err_typecheck_expect_flt_or_vector : Error<
   "invalid operand of type %0 where floating, complex or "
   "a vector of such types is required">;

diff  --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b0dd299edaf9cc..0078ceb7e892af 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -6244,8 +6244,20 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
   }
 
   // EmitHLSLBuiltinExpr will check getLangOpts().HLSL
-  if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
-    return RValue::get(V);
+  if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E, ReturnValue)) {
+    switch (EvalKind) {
+    case TEK_Scalar:
+      if (V->getType()->isVoidTy())
+        return RValue::get(nullptr);
+      return RValue::get(V);
+    case TEK_Aggregate:
+      return RValue::getAggregate(ReturnValue.getAddress(),
+                                  ReturnValue.isVolatile());
+    case TEK_Complex:
+      llvm_unreachable("No current hlsl builtin returns complex");
+    }
+    llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
+  }
 
   if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
     return EmitHipStdParUnsupportedBuiltin(this, FD);
@@ -18640,7 +18652,8 @@ Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
 }
 
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
-                                            const CallExpr *E) {
+                                            const CallExpr *E,
+                                            ReturnValueSlot ReturnValue) {
   if (!getLangOpts().HLSL)
     return nullptr;
 
@@ -18827,6 +18840,27 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.saturate");
   }
+  case Builtin::BI__builtin_hlsl_select: {
+    Value *OpCond = EmitScalarExpr(E->getArg(0));
+    RValue RValTrue = EmitAnyExpr(E->getArg(1));
+    Value *OpTrue =
+        RValTrue.isScalar()
+            ? RValTrue.getScalarVal()
+            : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
+    RValue RValFalse = EmitAnyExpr(E->getArg(2));
+    Value *OpFalse =
+        RValFalse.isScalar()
+            ? RValFalse.getScalarVal()
+            : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
+
+    Value *SelectVal =
+        Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
+    if (!RValTrue.isScalar())
+      Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
+                          ReturnValue.isVolatile());
+
+    return SelectVal;
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",

diff  --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 5892d6ac6f88a5..4eca770ca35d85 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4704,7 +4704,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
-  llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+  llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
+                                   ReturnValueSlot ReturnValue);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);

diff  --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 5c08a45a35377d..2ac18056b0fc3d 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1603,6 +1603,32 @@ double3 saturate(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
 double4 saturate(double4);
 
+//===----------------------------------------------------------------------===//
+// select builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T select(bool Cond, T TrueVal, T FalseVal)
+/// \brief Scalar-condition select: returns TrueVal if Cond, else FalseVal.
+/// \param Cond The condition that picks between the two values.
+/// \param TrueVal The value returned if Cond is true.
+/// \param FalseVal The value returned if Cond is false.
+
+template <typename T>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+T select(bool, T, T);
+
+/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
+///                         vector<T,Sz> FalseVals)
+/// \brief Per-element select; all three vectors must have the same size.
+/// \param Conds The per-element condition values.
+/// \param TrueVals Values chosen where the matching condition is true.
+/// \param FalseVals Values chosen where the matching condition is
+/// false.
+
+template <typename T, int Sz>
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
+vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
+
 //===----------------------------------------------------------------------===//
 // sin builtins
 //===----------------------------------------------------------------------===//

diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 3b40769939f12f..3b91303ac8cb8a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1531,6 +1531,79 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
   TheCall->setType(ReturnType);
 }
 
+static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
+                                unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() > ArgIndex && "argument index out of range");
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  auto *VTy = ArgType->getAs<VectorType>();
+  // Diagnose unless argument ArgIndex is `Scalar` or a vector of `Scalar`.
+  if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
+        (VTy &&
+         S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar)))) {
+    S->Diag(TheCall->getArg(ArgIndex)->getBeginLoc(),
+            diag::err_typecheck_expect_scalar_or_vector)
+        << ArgType << Scalar;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() == 3);
+  Expr *Arg1 = TheCall->getArg(1);
+  Expr *Arg2 = TheCall->getArg(2);
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_call_different_arg_types)
+        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+        << Arg2->getSourceRange();
+    return true;
+  }
+  // Scalar bool condition: the result takes the (shared) value type.
+  TheCall->setType(Arg1->getType());
+  return false;
+}
+
+static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
+  assert(TheCall->getNumArgs() == 3);
+  Expr *Arg1 = TheCall->getArg(1);
+  Expr *Arg2 = TheCall->getArg(2);
+  if (!Arg1->getType()->isVectorType()) {
+    S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
+        << "Second" << TheCall->getDirectCallee() << Arg1->getType()
+        << Arg1->getSourceRange();
+    return true;
+  }
+  // With a vector condition, the false-value operand must be a vector too.
+  if (!Arg2->getType()->isVectorType()) {
+    S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
+        << "Third" << TheCall->getDirectCallee() << Arg2->getType()
+        << Arg2->getSourceRange();
+    return true;
+  }
+  // Both value operands must share one type; it becomes the call's type.
+  if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_call_different_arg_types)
+        << Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
+        << Arg2->getSourceRange();
+    return true;
+  }
+
+  // The caller has checked that Arg0 is a vector of bool. Since Arg1 and
+  // Arg2 share a type, comparing Arg0's and Arg1's lengths covers all three.
+  if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
+      Arg1->getType()->getAs<VectorType>()->getNumElements()) {
+    S->Diag(TheCall->getBeginLoc(),
+            diag::err_typecheck_vector_lengths_not_equal)
+        << TheCall->getArg(0)->getType() << Arg1->getType()
+        << TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
+    return true;
+  }
+  TheCall->setType(Arg1->getType());
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -1563,6 +1636,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_select: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+    if (CheckScalarOrVector(&SemaRef, TheCall, getASTContext().BoolTy, 0))
+      return true;
+    QualType ArgTy = TheCall->getArg(0)->getType();
+    if (ArgTy->isBooleanType() && CheckBoolSelect(&SemaRef, TheCall))
+      return true;
+    auto *VTy = ArgTy->getAs<VectorType>();
+    if (VTy && VTy->getElementType()->isBooleanType() &&
+        CheckVectorSelect(&SemaRef, TheCall))
+      return true;
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_saturate:
   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
     if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))

diff  --git a/clang/test/CodeGenHLSL/builtins/select.hlsl b/clang/test/CodeGenHLSL/builtins/select.hlsl
new file mode 100644
index 00000000000000..cade938b71a2ba
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/select.hlsl
@@ -0,0 +1,54 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK
+// Verifies the IR emitted for each overload form of HLSL select().
+// CHECK-LABEL: test_select_bool_int
+// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, i32 {{%.*}}, i32 {{%.*}}
+// CHECK: ret i32 [[SELECT]]
+int test_select_bool_int(bool cond0, int tVal, int fVal) {
+  return select<int>(cond0, tVal, fVal);
+}
+
+struct S { int a; }; // NOTE(review): IR below stores a ptr, not the struct
+// CHECK-LABEL: test_select_infer
+// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, ptr {{%.*}}, ptr {{%.*}}
+// CHECK: store ptr [[SELECT]]
+// CHECK: ret void
+struct S test_select_infer(bool cond0, struct S tVal, struct S fVal) {
+  return select(cond0, tVal, fVal);
+}
+
+// CHECK-LABEL: test_select_bool_vector
+// CHECK: [[SELECT:%.*]] = select i1 {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
+// CHECK: ret <2 x i32> [[SELECT]]
+int2 test_select_bool_vector(bool cond0, int2 tVal, int2 fVal) {
+  return select<int2>(cond0, tVal, fVal);
+}
+
+// CHECK-LABEL: test_select_vector_1
+// CHECK: [[SELECT:%.*]] = select <1 x i1> {{%.*}}, <1 x i32> {{%.*}}, <1 x i32> {{%.*}}
+// CHECK: ret <1 x i32> [[SELECT]]
+int1 test_select_vector_1(bool1 cond0, int1 tVals, int1 fVals) {
+  return select<int,1>(cond0, tVals, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_2
+// CHECK: [[SELECT:%.*]] = select <2 x i1> {{%.*}}, <2 x i32> {{%.*}}, <2 x i32> {{%.*}}
+// CHECK: ret <2 x i32> [[SELECT]]
+int2 test_select_vector_2(bool2 cond0, int2 tVals, int2 fVals) {
+  return select<int,2>(cond0, tVals, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_3
+// CHECK: [[SELECT:%.*]] = select <3 x i1> {{%.*}}, <3 x i32> {{%.*}}, <3 x i32> {{%.*}}
+// CHECK: ret <3 x i32> [[SELECT]]
+int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
+  return select<int,3>(cond0, tVals, fVals);
+}
+
+// CHECK-LABEL: test_select_vector_4
+// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> {{%.*}}
+// CHECK: ret <4 x i32> [[SELECT]]
+int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
+  return select(cond0, tVals, fVals);
+}

diff  --git a/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
new file mode 100644
index 00000000000000..34b5fb6d54cd57
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/select-errors.hlsl
@@ -0,0 +1,119 @@
+// RUN: %clang_cc1 -finclude-default-header
+// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
+// -disable-llvm-passes -verify -verify-ignore-unexpected
+// NOTE(review): lines 2-3 lack a RUN prefix, so -verify never runs; confirm.
+int test_no_arg() {
+  return select();
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
+  // not viable: requires 3 arguments, but 0 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 0 were provided}}
+}
+
+int test_too_few_args(bool p0) {
+  return select(p0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 1 was provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 1 was provided}}
+}
+
+int test_too_many_args(bool p0, int t0, int f0, int g0) {
+  return select<int>(p0, t0, f0, g0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 4 were provided}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: requires 3 arguments, but 4 were provided}}
+}
+
+int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
+  return select(p0, t0, f0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
+  // to 'bool' for 1st argument}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
+  // not match 'vector<T, Sz>' against 'int'}}
+}
+
+int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
+  return select<int1>(p0, t0, f0);
+  // expected-warning@-1 {{implicit conversion truncates vector:
+  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
+  // (vector of 1 'int' value)}}
+}
+
+int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
+                                               int f0) {
+  return select(p0, t0, f0);
+  // expected-error@-1 {{no matching function for call to 'select'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
+  // could not match 'vector<T, Sz>' against 'int'}}
+  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
+  // viable: no known conversion from 'vector<bool, 2>'
+  // (vector of 2 'bool' values) to 'bool' for 1st argument}}
+}
+
+int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
+  return select<int,1>(p0, t0, f0); // produce warnings
+  // expected-warning@-1 {{implicit conversion truncates vector:
+  // 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
+  // (vector of 1 'bool' value)}}
+  // expected-warning@-2 {{implicit conversion truncates vector:
+  // 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
+  // (vector of 1 'int' value)}}
+}
+
+// __builtin_hlsl_select tests
+int test_select_builtin_wrong_arg_count(bool p0, int t0) {
+  return __builtin_hlsl_select(p0, t0);
+  // expected-error@-1 {{too few arguments to function call, expected 3,
+  // have 2}}
+}
+
+// not a bool or a vector of bool. should be 2 errors.
+int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{passing 'int' to parameter of incompatible type
+  // 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
+  // vector type}}
+}
+
+int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
+  // parameter of incompatible type 'bool'}}
+  // expected-error@-2 {{First argument to __builtin_hlsl_select must be of
+  // vector type}}
+}
+
+// if a bool last 2 args are of same type
+int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
+}
+
+// if a vector second arg isn't a vector
+int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
+  // vector type}}
+}
+
+// if a vector third arg isn't a vector
+int2 test_select_builtin_third_arg_not_vector(bool2 p0, int2 t0, int f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
+  // vector type}}
+}
+
+// if vector last 2 aren't same type (so both are vectors but wrong type)
+int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
+  return __builtin_hlsl_select(p0, t0, f0);
+  // expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
+  // vs 'vector<float, [...]>')}}
+}


        


More information about the cfe-commits mailing list