[clang] [llvm] Add step builtins and step HLSL function to DirectX and SPIR-V backend (PR #106471)

Joshua Batista via cfe-commits cfe-commits at lists.llvm.org
Wed Sep 11 17:14:24 PDT 2024


https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/106471

>From 7dfcf804ffcb850c7e3620866bc76f0971d82aaf Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Wed, 28 Aug 2024 16:35:36 -0700
Subject: [PATCH 1/2] add step

---
 clang/include/clang/Basic/Builtins.td         |   6 ++
 clang/lib/CodeGen/CGBuiltin.cpp               |  10 ++
 clang/lib/CodeGen/CGHLSLRuntime.h             |   1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  33 ++++++
 clang/lib/Sema/SemaHLSL.cpp                   |  12 +++
 clang/test/CodeGenHLSL/builtins/step.hlsl     | 100 ++++++++++++++++++
 clang/test/SemaHLSL/BuiltIns/step-errors.hlsl |  31 ++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   1 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   1 +
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  26 +++++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  24 +++++
 llvm/test/CodeGen/DirectX/step.ll             |  79 ++++++++++++++
 .../CodeGen/SPIRV/hlsl-intrinsics/step.ll     |  33 ++++++
 13 files changed, 357 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/step.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/step.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index ac33672a32b336..ba062d7b563749 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLStep: LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_step"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 4204c8ff276ab1..2a6a3b7f4852fa 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18695,6 +18695,16 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.saturate");
   }
+  case Builtin::BI__builtin_hlsl_step: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           "step operands must have a float representation");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     return EmitRuntimeCall(CGM.CreateRuntimeFunction(
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 55a4b97c160cd6..9c50181ade0945 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -80,6 +80,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 6d38b668fe770e..7ab7ea075c3109 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1691,6 +1691,39 @@ float3 sqrt(float3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_sqrt)
 float4 sqrt(float4);
 
+//===----------------------------------------------------------------------===//
+// step builtins
+//===----------------------------------------------------------------------===//
+
+/// \fn T step(T y, T x)
+/// \brief Returns 1 if the x parameter is greater than or equal to the y
+/// parameter; otherwise, 0. \param y [in] The first floating-point value to
+/// compare. \param x [in] The second floating-point value to compare.
+///
+/// Step is based on the following formula: (x >= y) ? 1 : 0
+
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+half step(half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+half2 step(half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+half3 step(half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+half4 step(half4, half4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+float step(float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+float2 step(float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+float3 step(float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_step)
+float4 step(float4, float4);
+
 //===----------------------------------------------------------------------===//
 // tan builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 714e8f5cfa9926..70ca25687de2db 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1490,6 +1490,18 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     TheCall->setType(ArgTyA);
     break;
   }
+  case Builtin::BI__builtin_hlsl_step: {
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    // return type is the same as the input type
+    TheCall->setType(ArgTyA);
+    break;
+  }
   // Note these are llvm builtins that we want to catch invalid intrinsic
   // generation. Normal handling of these builitns will occur elsewhere.
   case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/step.hlsl b/clang/test/CodeGenHLSL/builtins/step.hlsl
new file mode 100644
index 00000000000000..43312716449902
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/step.hlsl
@@ -0,0 +1,100 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF,SPIR_NATIVE_HALF,SPIR_CHECK
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK
+
+// DXIL_NATIVE_HALF: define noundef half @
+// SPIR_NATIVE_HALF: define spir_func noundef half @
+// DXIL_NATIVE_HALF: call half @llvm.dx.step.f16(half
+// SPIR_NATIVE_HALF: call half @llvm.spv.step.f16(half
+// DXIL_NO_HALF: call float @llvm.dx.step.f32(float
+// SPIR_NO_HALF: call float @llvm.spv.step.f32(float
+// NATIVE_HALF: ret half
+// NO_HALF: ret float
+half test_step_half(half p0, half p1)
+{
+    return step(p0, p1);
+}
+// DXIL_NATIVE_HALF: define noundef <2 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
+// DXIL_NATIVE_HALF: call <2 x half> @llvm.dx.step.v2f16(<2 x half>
+// SPIR_NATIVE_HALF: call <2 x half> @llvm.spv.step.v2f16(<2 x half>
+// DXIL_NO_HALF: call <2 x float> @llvm.dx.step.v2f32(<2 x float>
+// SPIR_NO_HALF: call <2 x float> @llvm.spv.step.v2f32(<2 x float>
+// NATIVE_HALF: ret <2 x half> %hlsl.step
+// NO_HALF: ret <2 x float> %hlsl.step
+half2 test_step_half2(half2 p0, half2 p1)
+{
+    return step(p0, p1);
+}
+// DXIL_NATIVE_HALF: define noundef <3 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
+// DXIL_NATIVE_HALF: call <3 x half> @llvm.dx.step.v3f16(<3 x half>
+// SPIR_NATIVE_HALF: call <3 x half> @llvm.spv.step.v3f16(<3 x half>
+// DXIL_NO_HALF: call <3 x float> @llvm.dx.step.v3f32(<3 x float>
+// SPIR_NO_HALF: call <3 x float> @llvm.spv.step.v3f32(<3 x float>
+// NATIVE_HALF: ret <3 x half> %hlsl.step
+// NO_HALF: ret <3 x float> %hlsl.step
+half3 test_step_half3(half3 p0, half3 p1)
+{
+    return step(p0, p1);
+}
+// DXIL_NATIVE_HALF: define noundef <4 x half> @
+// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
+// DXIL_NATIVE_HALF: call <4 x half> @llvm.dx.step.v4f16(<4 x half>
+// SPIR_NATIVE_HALF: call <4 x half> @llvm.spv.step.v4f16(<4 x half>
+// DXIL_NO_HALF: call <4 x float> @llvm.dx.step.v4f32(<4 x float>
+// SPIR_NO_HALF: call <4 x float> @llvm.spv.step.v4f32(<4 x float>
+// NATIVE_HALF: ret <4 x half> %hlsl.step
+// NO_HALF: ret <4 x float> %hlsl.step
+half4 test_step_half4(half4 p0, half4 p1)
+{
+    return step(p0, p1);
+}
+
+// DXIL_CHECK: define noundef float @
+// SPIR_CHECK: define spir_func noundef float @
+// DXIL_CHECK: call float @llvm.dx.step.f32(float
+// SPIR_CHECK: call float @llvm.spv.step.f32(float
+// CHECK: ret float
+float test_step_float(float p0, float p1)
+{
+    return step(p0, p1);
+}
+// DXIL_CHECK: define noundef <2 x float> @
+// SPIR_CHECK: define spir_func noundef <2 x float> @
+// DXIL_CHECK: %hlsl.step = call <2 x float> @llvm.dx.step.v2f32(
+// SPIR_CHECK: %hlsl.step = call <2 x float> @llvm.spv.step.v2f32(<2 x float>
+// CHECK: ret <2 x float> %hlsl.step
+float2 test_step_float2(float2 p0, float2 p1)
+{
+    return step(p0, p1);
+}
+// DXIL_CHECK: define noundef <3 x float> @
+// SPIR_CHECK: define spir_func noundef <3 x float> @
+// DXIL_CHECK: %hlsl.step = call <3 x float> @llvm.dx.step.v3f32(
+// SPIR_CHECK: %hlsl.step = call <3 x float> @llvm.spv.step.v3f32(<3 x float>
+// CHECK: ret <3 x float> %hlsl.step
+float3 test_step_float3(float3 p0, float3 p1)
+{
+    return step(p0, p1);
+}
+// DXIL_CHECK: define noundef <4 x float> @
+// SPIR_CHECK: define spir_func noundef <4 x float> @
+// DXIL_CHECK: %hlsl.step = call <4 x float> @llvm.dx.step.v4f32(
+// SPIR_CHECK: %hlsl.step = call <4 x float> @llvm.spv.step.v4f32(
+// CHECK: ret <4 x float> %hlsl.step
+float4 test_step_float4(float4 p0, float4 p1)
+{
+    return step(p0, p1);
+}
\ No newline at end of file
diff --git a/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
new file mode 100644
index 00000000000000..ccd21847f2f367
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -disable-llvm-passes -verify -verify-ignore-unexpected
+
+void test_too_few_arg()
+{
+  return __builtin_hlsl_step();
+  // expected-error at -1 {{too few arguments to function call, expected 2, have 0}}
+}
+
+void test_too_many_arg(float2 p0)
+{
+  return __builtin_hlsl_step(p0, p0, p0);
+  // expected-error at -1 {{too many arguments to function call, expected 2, have 3}}
+}
+
+bool builtin_bool_to_float_type_promotion(bool p1)
+{
+  return __builtin_hlsl_step(p1, p1);
+  // expected-error at -1 {{passing 'bool' to parameter of incompatible type 'float'}}
+}
+
+bool builtin_step_int_to_float_promotion(int p1)
+{
+  return __builtin_hlsl_step(p1, p1);
+  // expected-error at -1 {{passing 'int' to parameter of incompatible type 'float'}}
+}
+
+bool2 builtin_step_int2_to_float2_promotion(int2 p1)
+{
+  return __builtin_hlsl_step(p1, p1);
+  // expected-error at -1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 32af50b25f3904..741f1bb6cbe3cc 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -79,4 +79,5 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
 def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 63d9ba43a1183b..5ceaed7ee711fb 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -67,6 +67,7 @@ let TargetPrefix = "spv" in {
   def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
   def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+  def int_spv_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [LLVMMatchType<0>, llvm_anyfloat_ty]>;
   def int_spv_fdot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 2daa4f825c3b25..a04d6acc1a1352 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -48,6 +48,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::dx_fdot:
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
+  case Intrinsic::dx_step:
     return true;
   }
   return false;
@@ -320,6 +321,28 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
   return Exp2Call;
 }
 
+static Value *expandStepIntrinsic(CallInst *Orig) {
+
+  Value *X = Orig->getOperand(0);
+  Value *Y = Orig->getOperand(1);
+  Type *Ty = X->getType();
+  IRBuilder<> Builder(Orig);
+
+  Constant *one = ConstantFP::get(Ty->getScalarType(), 1.0);
+  Constant *zero = ConstantFP::get(Ty->getScalarType(), 0.0);
+  Value *cond = Builder.CreateFCmpOLT(Y, X);
+
+  if (Ty != Ty->getScalarType()) {
+    auto *XVec = dyn_cast<FixedVectorType>(Ty);
+    one = ConstantVector::getSplat(
+        ElementCount::getFixed(XVec->getNumElements()), one);
+    zero = ConstantVector::getSplat(
+        ElementCount::getFixed(XVec->getNumElements()), zero);
+  }
+
+  return Builder.CreateSelect(cond, zero, one);
+}
+
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
                                     Intrinsic::ID ClampIntrinsic) {
   if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -402,6 +425,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_udot:
     Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
     break;
+  case Intrinsic::dx_step:
+    Result = expandStepIntrinsic(Orig);
+    break;
   }
 
   if (Result) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 9e10d947081cc3..23cb42fb97130f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -259,6 +259,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
+  bool selectStep(Register ResVReg, const SPIRVType *ResType,
+                  MachineInstr &I) const;
+
   bool selectUnmergeValues(MachineInstr &I) const;
 
   Register buildI32Constant(uint32_t Val, MachineInstr &I,
@@ -1603,6 +1606,25 @@ bool SPIRVInstructionSelector::selectSaturate(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectStep(Register ResVReg,
+                                          const SPIRVType *ResType,
+                                          MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 4);
+  assert(I.getOperand(2).isReg());
+  assert(I.getOperand(3).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::Step)
+      .addUse(I.getOperand(2).getReg())
+      .addUse(I.getOperand(3).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
@@ -2351,6 +2373,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   } break;
   case Intrinsic::spv_saturate:
     return selectSaturate(ResVReg, ResType, I);
+  case Intrinsic::spv_step:
+    return selectStep(ResVReg, ResType, I);
   default: {
     std::string DiagMsg;
     raw_string_ostream OS(DiagMsg);
diff --git a/llvm/test/CodeGen/DirectX/step.ll b/llvm/test/CodeGen/DirectX/step.ll
new file mode 100644
index 00000000000000..9a25a371f6efd0
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/step.ll
@@ -0,0 +1,79 @@
+; RUN: opt -S  -dxil-intrinsic-expansion  < %s | FileCheck %s --check-prefix=CHECK
+; RUN: opt -S  -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s --check-prefix=CHECK
+
+; Make sure dxil operation function calls for step are generated for half/float.
+
+declare half @llvm.dx.step.f16(half, half)
+declare <2 x half> @llvm.dx.step.v2f16(<2 x half>, <2 x half>)
+declare <3 x half> @llvm.dx.step.v3f16(<3 x half>, <3 x half>)
+declare <4 x half> @llvm.dx.step.v4f16(<4 x half>, <4 x half>)
+
+declare float @llvm.dx.step.f32(float, float)
+declare <2 x float> @llvm.dx.step.v2f32(<2 x float>, <2 x float>)
+declare <3 x float> @llvm.dx.step.v3f32(<3 x float>, <3 x float>)
+declare <4 x float> @llvm.dx.step.v4f32(<4 x float>, <4 x float>)
+
+define noundef half @test_step_half(half noundef %p0, half noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt half %p1, %p0
+  ; CHECK: %1 = select i1 %0, half 0xH0000, half 0xH3C00
+  ; DOPCHECK: asdahg
+  %hlsl.step = call half @llvm.dx.step.f16(half %p0, half %p1)
+  ret half %hlsl.step
+}
+
+define noundef <2 x half> @test_step_half2(<2 x half> noundef %p0, <2 x half> noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt <2 x half> %p1, %p0
+  ; CHECK: %1 = select <2 x i1> %0, <2 x half> zeroinitializer, <2 x half> <half 0xH3C00, half 0xH3C00>
+  %hlsl.step = call <2 x half> @llvm.dx.step.v2f16(<2 x half> %p0, <2 x half> %p1)
+  ret <2 x half> %hlsl.step
+}
+
+define noundef <3 x half> @test_step_half3(<3 x half> noundef %p0, <3 x half> noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt <3 x half> %p1, %p0
+  ; CHECK: %1 = select <3 x i1> %0, <3 x half> zeroinitializer, <3 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00>
+  %hlsl.step = call <3 x half> @llvm.dx.step.v3f16(<3 x half> %p0, <3 x half> %p1)
+  ret <3 x half> %hlsl.step
+}
+
+define noundef <4 x half> @test_step_half4(<4 x half> noundef %p0, <4 x half> noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt <4 x half> %p1, %p0
+  ; CHECK: %1 = select <4 x i1> %0, <4 x half> zeroinitializer, <4 x half> <half 0xH3C00, half 0xH3C00, half 0xH3C00, half 0xH3C00>
+  %hlsl.step = call <4 x half> @llvm.dx.step.v4f16(<4 x half> %p0, <4 x half> %p1)
+  ret <4 x half> %hlsl.step
+}
+
+define noundef float @test_step_float(float noundef %p0, float noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt float %p1, %p0
+  ; CHECK: %1 = select i1 %0, float 0.000000e+00, float 1.000000e+00
+  %hlsl.step = call float @llvm.dx.step.f32(float %p0, float %p1)
+  ret float %hlsl.step
+}
+
+define noundef <2 x float> @test_step_float2(<2 x float> noundef %p0, <2 x float> noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt <2 x float> %p1, %p0
+  ; CHECK: %1 = select <2 x i1> %0, <2 x float> zeroinitializer, <2 x float> <float 1.000000e+00, float 1.000000e+00>
+  %hlsl.step = call <2 x float> @llvm.dx.step.v2f32(<2 x float> %p0, <2 x float> %p1)
+  ret <2 x float> %hlsl.step
+}
+
+define noundef <3 x float> @test_step_float3(<3 x float> noundef %p0, <3 x float> noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt <3 x float> %p1, %p0
+  ; CHECK: %1 = select <3 x i1> %0, <3 x float> zeroinitializer, <3 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %hlsl.step = call <3 x float> @llvm.dx.step.v3f32(<3 x float> %p0, <3 x float> %p1)
+  ret <3 x float> %hlsl.step
+}
+
+define noundef <4 x float> @test_step_float4(<4 x float> noundef %p0, <4 x float> noundef %p1) {
+entry:
+  ; CHECK: %0 = fcmp olt <4 x float> %p1, %p0
+  ; CHECK: %1 = select <4 x i1> %0, <4 x float> zeroinitializer, <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>
+  %hlsl.step = call <4 x float> @llvm.dx.step.v4f32(<4 x float> %p0, <4 x float> %p1)
+  ret <4 x float> %hlsl.step
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll
new file mode 100644
index 00000000000000..c45810f362970f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for step are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+define noundef <4 x half> @step_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_16]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Step %[[#arg0]] %[[#arg1]]
+  %hlsl.step = call <4 x half> @llvm.spv.step.v4f16(<4 x half> %a, <4 x half> %b)
+  ret <4 x half> %hlsl.step
+}
+
+define noundef <4 x float> @step_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Step %[[#arg0]] %[[#arg1]]
+  %hlsl.step = call <4 x float> @llvm.spv.step.v4f32(<4 x float> %a, <4 x float> %b)
+  ret <4 x float> %hlsl.step
+}
+
+declare <4 x half> @llvm.spv.step.v4f16(<4 x half>, <4 x half>)
+declare <4 x float> @llvm.spv.step.v4f32(<4 x float>, <4 x float>)
\ No newline at end of file

>From af4a0f665572eaf7ef1d7919d5befc18490e4e41 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Wed, 11 Sep 2024 17:14:06 -0700
Subject: [PATCH 2/2] address farzon / justin, formatting

---
 clang/include/clang/Basic/Builtins.td         |  1 +
 clang/lib/Sema/SemaHLSL.cpp                   |  4 +-
 clang/test/CodeGenHLSL/builtins/step.hlsl     | 74 ++++++++-----------
 clang/test/SemaHLSL/BuiltIns/step-errors.hlsl |  2 +-
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 18 ++---
 llvm/test/CodeGen/DirectX/step.ll             |  1 -
 .../CodeGen/SPIRV/hlsl-intrinsics/step.ll     |  2 +-
 7 files changed, 43 insertions(+), 59 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b9fedeb3c9e92b..6cf03d27055cd9 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4780,6 +4780,7 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_step"];
   let Attributes = [NoThrow, Const];
   let Prototype = "void(...)";
+}
 
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 2b55538852d3ba..527718c8e7e324 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1748,10 +1748,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     break;
   }
   case Builtin::BI__builtin_hlsl_step: {
-    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
-      return true;
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
+    if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
+      return true;
 
     ExprResult A = TheCall->getArg(0);
     QualType ArgTyA = A.get()->getType();
diff --git a/clang/test/CodeGenHLSL/builtins/step.hlsl b/clang/test/CodeGenHLSL/builtins/step.hlsl
index 43312716449902..442f4930ca579c 100644
--- a/clang/test/CodeGenHLSL/builtins/step.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/step.hlsl
@@ -1,60 +1,52 @@
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
 // RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
 // RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
-// RUN:   --check-prefixes=CHECK,DXIL_CHECK,DXIL_NATIVE_HALF,NATIVE_HALF
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF \
+// RUN:   -DFNATTRS=noundef -DTARGET=dx
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
 // RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
-// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,DXIL_CHECK,NO_HALF,DXIL_NO_HALF
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
+// RUN:   -DFNATTRS=noundef -DTARGET=dx
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
 // RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
 // RUN:   -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
-// RUN:   --check-prefixes=CHECK,NATIVE_HALF,SPIR_NATIVE_HALF,SPIR_CHECK
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF \
+// RUN:   -DFNATTRS="spir_func noundef" -DTARGET=spv
 // RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
 // RUN:   spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
-// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF,SPIR_NO_HALF,SPIR_CHECK
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF \
+// RUN:   -DFNATTRS="spir_func noundef" -DTARGET=spv
 
-// DXIL_NATIVE_HALF: define noundef half @
-// SPIR_NATIVE_HALF: define spir_func noundef half @
-// DXIL_NATIVE_HALF: call half @llvm.dx.step.f16(half
-// SPIR_NATIVE_HALF: call half @llvm.spv.step.f16(half
-// DXIL_NO_HALF: call float @llvm.dx.step.f32(float
-// SPIR_NO_HALF: call float @llvm.spv.step.f32(float
+// NATIVE_HALF: define [[FNATTRS]] half @
+// NATIVE_HALF: call half @llvm.[[TARGET]].step.f16(half
+// NO_HALF: call float @llvm.[[TARGET]].step.f32(float
 // NATIVE_HALF: ret half
 // NO_HALF: ret float
 half test_step_half(half p0, half p1)
 {
     return step(p0, p1);
 }
-// DXIL_NATIVE_HALF: define noundef <2 x half> @
-// SPIR_NATIVE_HALF: define spir_func noundef <2 x half> @
-// DXIL_NATIVE_HALF: call <2 x half> @llvm.dx.step.v2f16(<2 x half>
-// SPIR_NATIVE_HALF: call <2 x half> @llvm.spv.step.v2f16(<2 x half>
-// DXIL_NO_HALF: call <2 x float> @llvm.dx.step.v2f32(<2 x float>
-// SPIR_NO_HALF: call <2 x float> @llvm.spv.step.v2f32(<2 x float>
+// NATIVE_HALF: define [[FNATTRS]] <2 x half> @
+// NATIVE_HALF: call <2 x half> @llvm.[[TARGET]].step.v2f16(<2 x half>
+// NO_HALF: call <2 x float> @llvm.[[TARGET]].step.v2f32(<2 x float>
 // NATIVE_HALF: ret <2 x half> %hlsl.step
 // NO_HALF: ret <2 x float> %hlsl.step
 half2 test_step_half2(half2 p0, half2 p1)
 {
     return step(p0, p1);
 }
-// DXIL_NATIVE_HALF: define noundef <3 x half> @
-// SPIR_NATIVE_HALF: define spir_func noundef <3 x half> @
-// DXIL_NATIVE_HALF: call <3 x half> @llvm.dx.step.v3f16(<3 x half>
-// SPIR_NATIVE_HALF: call <3 x half> @llvm.spv.step.v3f16(<3 x half>
-// DXIL_NO_HALF: call <3 x float> @llvm.dx.step.v3f32(<3 x float>
-// SPIR_NO_HALF: call <3 x float> @llvm.spv.step.v3f32(<3 x float>
+// NATIVE_HALF: define [[FNATTRS]] <3 x half> @
+// NATIVE_HALF: call <3 x half> @llvm.[[TARGET]].step.v3f16(<3 x half>
+// NO_HALF: call <3 x float> @llvm.[[TARGET]].step.v3f32(<3 x float>
 // NATIVE_HALF: ret <3 x half> %hlsl.step
 // NO_HALF: ret <3 x float> %hlsl.step
 half3 test_step_half3(half3 p0, half3 p1)
 {
     return step(p0, p1);
 }
-// DXIL_NATIVE_HALF: define noundef <4 x half> @
-// SPIR_NATIVE_HALF: define spir_func noundef <4 x half> @
-// DXIL_NATIVE_HALF: call <4 x half> @llvm.dx.step.v4f16(<4 x half>
-// SPIR_NATIVE_HALF: call <4 x half> @llvm.spv.step.v4f16(<4 x half>
-// DXIL_NO_HALF: call <4 x float> @llvm.dx.step.v4f32(<4 x float>
-// SPIR_NO_HALF: call <4 x float> @llvm.spv.step.v4f32(<4 x float>
+// NATIVE_HALF: define [[FNATTRS]] <4 x half> @
+// NATIVE_HALF: call <4 x half> @llvm.[[TARGET]].step.v4f16(<4 x half>
+// NO_HALF: call <4 x float> @llvm.[[TARGET]].step.v4f32(<4 x float>
 // NATIVE_HALF: ret <4 x half> %hlsl.step
 // NO_HALF: ret <4 x float> %hlsl.step
 half4 test_step_half4(half4 p0, half4 p1)
@@ -62,39 +54,31 @@ half4 test_step_half4(half4 p0, half4 p1)
     return step(p0, p1);
 }
 
-// DXIL_CHECK: define noundef float @
-// SPIR_CHECK: define spir_func noundef float @
-// DXIL_CHECK: call float @llvm.dx.step.f32(float
-// SPIR_CHECK: call float @llvm.spv.step.f32(float
+// CHECK: define [[FNATTRS]] float @
+// CHECK: call float @llvm.[[TARGET]].step.f32(float
 // CHECK: ret float
 float test_step_float(float p0, float p1)
 {
     return step(p0, p1);
 }
-// DXIL_CHECK: define noundef <2 x float> @
-// SPIR_CHECK: define spir_func noundef <2 x float> @
-// DXIL_CHECK: %hlsl.step = call <2 x float> @llvm.dx.step.v2f32(
-// SPIR_CHECK: %hlsl.step = call <2 x float> @llvm.spv.step.v2f32(<2 x float>
+// CHECK: define [[FNATTRS]] <2 x float> @
+// CHECK: %hlsl.step = call <2 x float> @llvm.[[TARGET]].step.v2f32(
 // CHECK: ret <2 x float> %hlsl.step
 float2 test_step_float2(float2 p0, float2 p1)
 {
     return step(p0, p1);
 }
-// DXIL_CHECK: define noundef <3 x float> @
-// SPIR_CHECK: define spir_func noundef <3 x float> @
-// DXIL_CHECK: %hlsl.step = call <3 x float> @llvm.dx.step.v3f32(
-// SPIR_CHECK: %hlsl.step = call <3 x float> @llvm.spv.step.v3f32(<3 x float>
+// CHECK: define [[FNATTRS]] <3 x float> @
+// CHECK: %hlsl.step = call <3 x float> @llvm.[[TARGET]].step.v3f32(
 // CHECK: ret <3 x float> %hlsl.step
 float3 test_step_float3(float3 p0, float3 p1)
 {
     return step(p0, p1);
 }
-// DXIL_CHECK: define noundef <4 x float> @
-// SPIR_CHECK: define spir_func noundef <4 x float> @
-// DXIL_CHECK: %hlsl.step = call <4 x float> @llvm.dx.step.v4f32(
-// SPIR_CHECK: %hlsl.step = call <4 x float> @llvm.spv.step.v4f32(
+// CHECK: define [[FNATTRS]] <4 x float> @
+// CHECK: %hlsl.step = call <4 x float> @llvm.[[TARGET]].step.v4f32(
 // CHECK: ret <4 x float> %hlsl.step
 float4 test_step_float4(float4 p0, float4 p1)
 {
     return step(p0, p1);
-}
\ No newline at end of file
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
index ccd21847f2f367..823585201ca62d 100644
--- a/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/step-errors.hlsl
@@ -28,4 +28,4 @@ bool2 builtin_step_int2_to_float2_promotion(int2 p1)
 {
   return __builtin_hlsl_step(p1, p1);
   // expected-error at -1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
-}
\ No newline at end of file
+}
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 8ec36103cc16bf..8516a240b21f7d 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -330,19 +330,19 @@ static Value *expandStepIntrinsic(CallInst *Orig) {
   Type *Ty = X->getType();
   IRBuilder<> Builder(Orig);
 
-  Constant *one = ConstantFP::get(Ty->getScalarType(), 1.0);
-  Constant *zero = ConstantFP::get(Ty->getScalarType(), 0.0);
-  Value *cond = Builder.CreateFCmpOLT(Y, X);
+  Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0);
+  Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0);
+  Value *Cond = Builder.CreateFCmpOLT(Y, X);
 
   if (Ty != Ty->getScalarType()) {
     auto *XVec = dyn_cast<FixedVectorType>(Ty);
-    one = ConstantVector::getSplat(
-        ElementCount::getFixed(XVec->getNumElements()), one);
-    zero = ConstantVector::getSplat(
-        ElementCount::getFixed(XVec->getNumElements()), zero);
+    One = ConstantVector::getSplat(
+        ElementCount::getFixed(XVec->getNumElements()), One);
+    Zero = ConstantVector::getSplat(
+        ElementCount::getFixed(XVec->getNumElements()), Zero);
   }
 
-  return Builder.CreateSelect(cond, zero, one);
+  return Builder.CreateSelect(Cond, Zero, One);
 }
 
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -456,9 +456,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::dx_sign:
     Result = expandSignIntrinsic(Orig);
     break;
-  }
   case Intrinsic::dx_step:
     Result = expandStepIntrinsic(Orig);
+  }
   if (Result) {
     Orig->replaceAllUsesWith(Result);
     Orig->eraseFromParent();
diff --git a/llvm/test/CodeGen/DirectX/step.ll b/llvm/test/CodeGen/DirectX/step.ll
index 9a25a371f6efd0..0393c1533cb939 100644
--- a/llvm/test/CodeGen/DirectX/step.ll
+++ b/llvm/test/CodeGen/DirectX/step.ll
@@ -17,7 +17,6 @@ define noundef half @test_step_half(half noundef %p0, half noundef %p1) {
 entry:
   ; CHECK: %0 = fcmp olt half %p1, %p0
   ; CHECK: %1 = select i1 %0, half 0xH0000, half 0xH3C00
-  ; DOPCHECK: asdahg
   %hlsl.step = call half @llvm.dx.step.f16(half %p0, half %p1)
   ret half %hlsl.step
 }
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll
index c45810f362970f..bb50d8c790f8ad 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll
@@ -30,4 +30,4 @@ entry:
 }
 
 declare <4 x half> @llvm.spv.step.v4f16(<4 x half>, <4 x half>)
-declare <4 x float> @llvm.spv.step.v4f32(<4 x float>, <4 x float>)
\ No newline at end of file
+declare <4 x float> @llvm.spv.step.v4f32(<4 x float>, <4 x float>)



More information about the cfe-commits mailing list