[clang] [llvm] [HLSL][DXIL] Implement `refract` intrinsic (PR #147342)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 10:30:54 PDT 2025


https://github.com/raoanag updated https://github.com/llvm/llvm-project/pull/147342

>From 2b5c51361a793abc80f45b04e2830d1ec551ef73 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 16 Apr 2025 11:47:27 -0700
Subject: [PATCH 01/23] Refract implementation

---
 clang/test/CodeGenHLSL/builtins/refract.hlsl  | 92 +++++++++++++++++++
 clang/test/CodeGenSPIRV/Builtins/refract.c    | 32 +++++++
 .../SemaHLSL/BuiltIns/refract-errors.hlsl     | 74 +++++++++++++++
 .../CodeGen/SPIRV/hlsl-intrinsics/refract.ll  | 33 +++++++
 4 files changed, 231 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/refract.hlsl
 create mode 100644 clang/test/CodeGenSPIRV/Builtins/refract.c
 create mode 100644 clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll

diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
new file mode 100644
index 0000000000000..4dc5b66251b62
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -0,0 +1,92 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -triple \
+// RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
+// RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_refract_floatff(
+// CHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[I]], 2.000000e+00
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[N]]
+// CHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[MUL_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[I]], [[MUL2_I]]
+// CHECK-NEXT:    ret float [[SUB_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatff(
+// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[I]], 2.000000e+00
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[N]]
+// SPVCHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[MUL_I]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[I]], [[MUL2_I]]
+// SPVCHECK-NEXT:    ret float [[SUB_I]]
+//
+float test_refract_float(float I, float N, float eta) {
+    return refract(I, N, eta);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_(
+// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> [[I]], <2 x float> [[N]])
+// CHECK-NEXT:    [[DOTSCALAR:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], 2.000000e+00
+// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x float> poison, float [[DOTSCALAR]], i64 0
+// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[TMP0]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[TMP1]], [[N]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[I]], [[MUL1_I]]
+// CHECK-NEXT:    ret <2 x float> [[SUB_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_(
+// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[SPV_refract_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32(<2 x float> [[I]], <2 x float> [[N]])
+// SPVCHECK-NEXT:    ret <2 x float> [[SPV_refract_I]]
+//
+float2 test_refract_float2(float2 I, float2 N, float eta) {
+    return refract(I, N, eta);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_(
+// CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> [[I]], <3 x float> [[N]])
+// CHECK-NEXT:    [[DOTSCALAR:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], 2.000000e+00
+// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <3 x float> poison, float [[DOTSCALAR]], i64 0
+// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x float> [[TMP0]], <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[TMP1]], [[N]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[I]], [[MUL1_I]]
+// CHECK-NEXT:    ret <3 x float> [[SUB_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_(
+// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[SPV_refract_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32(<3 x float> [[I]], <3 x float> [[N]])
+// SPVCHECK-NEXT:    ret <3 x float> [[SPV_refract_I]]
+//
+float3 test_refract_float3(float3 I, float3 N, float eta) {
+    return refract(I, N, eta);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_(
+// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> [[I]], <4 x float> [[N]])
+// CHECK-NEXT:    [[DOTSCALAR:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], 2.000000e+00
+// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x float> poison, float [[DOTSCALAR]], i64 0
+// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[TMP0]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[TMP1]], [[N]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[I]], [[MUL1_I]]
+// CHECK-NEXT:    ret <4 x float> [[SUB_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_(
+// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[SPV_refract_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32(<4 x float> [[I]], <4 x float> [[N]])
+// SPVCHECK-NEXT:    ret <4 x float> [[SPV_refract_I]]
+//
+float4 test_refract_float4(float4 I, float4 N, float eta) {
+    return refract(I, N, eta);
+}
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
new file mode 100644
index 0000000000000..06498554bd4d1
--- /dev/null
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -0,0 +1,32 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+
+// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
+
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float3 __attribute__((ext_vector_type(3)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+// CHECK-LABEL: define spir_func <2 x float> @test_refract_float2(
+// CHECK-SAME: <2 x float> noundef [[X:%.*]], <2 x float> noundef [[Y:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> [[X]], <2 x float> [[Y]], float [[ETA]])
+// CHECK-NEXT:    ret <2 x float> [[SPV_REFRACT]]
+//
+float2 test_refract_float2(float2 X, float2 Y, float eta) { return __builtin_spirv_refract(X, Y, eta); }
+
+// CHECK-LABEL: define spir_func <3 x float> @test_refract_float3(
+// CHECK-SAME: <3 x float> noundef [[X:%.*]], <3 x float> noundef [[Y:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> [[X]], <3 x float> [[Y]], float [[ETA]])
+// CHECK-NEXT:    ret <3 x float> [[SPV_REFRACT]]
+//
+float3 test_refract_float3(float3 X, float3 Y, float eta) { return __builtin_spirv_refract(X, Y, eta); }
+
+// CHECK-LABEL: define spir_func <4 x float> @test_refract_float4(
+// CHECK-SAME: <4 x float> noundef [[X:%.*]], <4 x float> noundef [[Y:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> [[X]], <4 x float> [[Y]], float [[ETA]])
+// CHECK-NEXT:    ret <4 x float> [[SPV_REFRACT]]
+//
+float4 test_refract_float4(float4 X, float4 Y, float eta) { return __builtin_spirv_refract(X, Y, eta); }
+
diff --git a/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
new file mode 100644
index 0000000000000..eee913e6bb6e7
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
@@ -0,0 +1,74 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify
+
+float test_no_second_arg(float3 p0) {
+  return refract(p0);
+  // expected-error at -1 {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 1 was provided}}
+}
+
+float test_no_third_arg(float3 p0) {
+  return refract(p0, p0);
+  // expected-error at -1 {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 2 were provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 2 were provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 2 were provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 2 were provided}}
+}
+
+float test_too_many_arg(float2 p0) {
+  return refract(p0, p0, p0, p0);
+  // expected-error at -1 {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires 3 arguments, but 4 were provided}}
+}
+
+float test_double_inputs(double p0, double p1, double p2) {
+  return refract(p0, p1, p2);
+  // expected-error at -1  {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+}
+
+float test_int_inputs(int p0, int p1, int p2) {
+  return refract(p0, p1, p2);
+  // expected-error at -1  {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored}}
+}
+
+float1 test_vec1_inputs(float1 p0, float1 p1, float1 p2) {
+  return refract(p0, p1, p2);
+  // expected-error at -1  {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with T = float1]: no type named 'Type' in 'hlsl::__detail::enable_if<false, vector<float, 1>>'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with T = float1]: no type named 'Type' in 'hlsl::__detail::enable_if<false, vector<float, 1>>'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 1]: no type named 'Type' in 'hlsl::__detail::enable_if<false, half>'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 1]: no type named 'Type' in 'hlsl::__detail::enable_if<false, float>'}}
+}
+
+float3 test_mixed_datatype_inputs(float3 p0, float3 p1, half p2) {
+  return refract(p0, p1, p2);
+}
+
+half3 test_mixed_datatype_inputs(half3 p0, half3 p1, float p2) {
+  return refract(p0, p1, p2);
+}
+
+typedef float float5 __attribute__((ext_vector_type(5)));
+
+float5 test_vec5_inputs(float5 p0, float5 p1,  float p2) {
+  return refract(p0, p1, p2);
+  // expected-error at -1  {{no matching function for call to 'refract'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: deduced conflicting types for parameter 'T' ('float5' (vector of 5 'float' values) vs. 'float')}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: deduced conflicting types for parameter 'T' ('float5' (vector of 5 'float' values) vs. 'float')}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 5]: no type named 'Type' in 'hlsl::__detail::enable_if<false, half>'}}
+  // expected-note at hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 5]: no type named 'Type' in 'hlsl::__detail::enable_if<false, float>'}}
+}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
new file mode 100644
index 0000000000000..7dc6d24f651de
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -0,0 +1,35 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure SPIRV operation function calls for refract are lowered correctly.
+
+; CHECK-DAG: %[[#op_ext_glsl:]] = OpExtInstImport "GLSL.std.450"
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+
+define noundef <4 x half> @refract_half4(<4 x half> noundef %a, <4 x half> noundef %b, half %eta) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_16]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
+  ; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#float_16]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
+  %spv.refract = call <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half> %a, <4 x half> %b, half %eta)
+  ret <4 x half> %spv.refract
+}
+
+define noundef <4 x float> @refract_float4(<4 x float> noundef %a, <4 x float> noundef %b, float %eta) {
+entry:
+  ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
+  ; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#float_32]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
+  %spv.refract = call <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> %a, <4 x float> %b, float %eta)
+  ret <4 x float> %spv.refract
+}
+
+declare <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half>, <4 x half>, half)
+declare <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float>, <4 x float>, float)

>From 7feddb7e715ef605552e3182fa0ae70a1de6f5ea Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 16 Apr 2025 13:48:06 -0700
Subject: [PATCH 02/23] Adding dxil tests

---
 clang/include/clang/Basic/BuiltinsSPIRV.td    | 39 ++++++++++
 clang/lib/CodeGen/TargetBuiltins/SPIR.cpp     | 15 ++++
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 36 +++++++++
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 77 +++++++++++++++++++
 clang/lib/Sema/SemaSPIRV.cpp                  | 46 ++++++++++-
 clang/test/CodeGenHLSL/builtins/reflect.hlsl  | 32 ++++++--
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  1 +
 llvm/lib/IR/IRBuilder.cpp                     | 21 ++++-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  2 +
 9 files changed, 260 insertions(+), 9 deletions(-)
 create mode 100644 clang/include/clang/Basic/BuiltinsSPIRV.td

diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
new file mode 100644
index 0000000000000..c0f652b4f24e4
--- /dev/null
+++ b/clang/include/clang/Basic/BuiltinsSPIRV.td
@@ -0,0 +1,39 @@
+//===--- BuiltinsSPIRV.td - SPIRV Builtin function database ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+include "clang/Basic/BuiltinsBase.td"
+
+def SPIRVDistance : Builtin {
+  let Spellings = ["__builtin_spirv_distance"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
+def SPIRVLength : Builtin {
+  let Spellings = ["__builtin_spirv_length"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
+def SPIRVReflect : Builtin {
+  let Spellings = ["__builtin_spirv_reflect"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
+def SPIRVRefract : Builtin {
+  let Spellings = ["__builtin_spirv_refract"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
+def SPIRVSmoothStep : Builtin {
+  let Spellings = ["__builtin_spirv_smoothstep"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 0687485cd3f80..026444472a1ac 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -58,6 +58,22 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
         ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    // I and N must be floating-point vectors; Eta must be a float scalar.
+    Value *I = EmitScalarExpr(E->getArg(0));
+    Value *N = EmitScalarExpr(E->getArg(1));
+    Value *Eta = EmitScalarExpr(E->getArg(2));
+    assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
+           E->getArg(1)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->isFloatingType() &&
+           "refract operands must have a float representation");
+    assert(E->getArg(0)->getType()->isVectorType() &&
+           E->getArg(1)->getType()->isVectorType() &&
+           "refract I and N operands must be a vector");
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
+        ArrayRef<Value *>{I, N, Eta}, nullptr, "spv.refract");
+  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     Value *Min = EmitScalarExpr(E->getArg(0));
     Value *Max = EmitScalarExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 4eb7b8f45c85a..62c7f67645e99 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,6 +71,31 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
+// Scalar refract. For scalars dot(N, I) reduces to N * I. Returns 0 when
+// k < 0 (total internal reflection).
+template <typename T> constexpr T refract_impl(T I, T N, T eta) {
+  T NDotI = N * I;
+  T k = 1 - eta * eta * (1 - NDotI * NDotI);
+  if (k < 0)
+    return 0;
+  return eta * I - (eta * NDotI + sqrt(k)) * N;
+}
+
+// Vector refract. k depends only on the scalars eta and dot(N, I), so it is a
+// scalar and the total-internal-reflection test is a scalar comparison.
+template <typename T, int L>
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  return __builtin_spirv_refract(I, N, eta);
+#else
+  T NDotI = dot(N, I);
+  T k = 1 - eta * eta * (1 - NDotI * NDotI);
+  if (k < 0)
+    return 0;
+  return eta * I - (eta * NDotI + sqrt(k)) * N;
+#endif
+}
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index ea880105fac3b..bed2ad1a498b2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
   return __detail::reflect_vec_impl(I, N);
 }
 
+//===----------------------------------------------------------------------===//
+// refract builtin
+//===----------------------------------------------------------------------===//
+
+/// \fn T refract(T I, T N, T eta)
+/// \brief Returns a refraction using an entering ray, \a I, a surface
+/// normal, \a N, and refraction index, \a eta.
+/// \param I The entering ray.
+/// \param N The surface normal.
+/// \param eta The refraction index.
+///
+/// The return value is a floating-point scalar or vector that represents the
+/// refraction of the entering ray, \a I, through a surface with the normal
+/// \a N, using the refraction index \a eta.
+///
+/// This function calculates the refraction vector using the following formulas:
+/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
+/// if k < 0.0 the result is 0.0
+/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
+///
+/// I and N must already be normalized in order to achieve the desired result.
+///
+/// I and N must be a scalar or vector whose component type is
+/// floating-point.
+///
+/// eta must be a 16-bit or 32-bit floating-point scalar.
+///
+/// Result type, the type of I, and the type of N must all be the same type.
+
+template <typename T>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
+                                       __detail::is_same<half, T>::value,
+                                   T> refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <typename T>
+const inline __detail::enable_if_t<
+    __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
+    refract(T I, T N, T eta) {
+  return __detail::refract_impl(I, N, eta);
+}
+
+template <int L>
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
+    __detail::HLSL_FIXED_VECTOR<half, L> I,
+    __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
+template <int L>
+const inline __detail::HLSL_FIXED_VECTOR<float, L>
+refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
+        __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
+  return __detail::refract_vec_impl(I, N, eta);
+}
+
 //===----------------------------------------------------------------------===//
 // smoothstep builtin
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index c27d3fed2b990..8fbf30ccfe799 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -203,6 +203,50 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     TheCall->setType(RetTy);
     break;
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    // eta must be a floating-point *scalar*; hasFloatingRepresentation()
+    // would also accept a vector of floats here.
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->isFloatingType()) {
+      SemaRef.Diag(C.get()->getBeginLoc(),
+                   diag::err_builtin_invalid_arg_type)
+          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << ArgTyC;
+      return true;
+    }
+
+    // The result has the same type as the incident vector I.
+    TheCall->setType(ArgTyA);
+    break;
+  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
@@ -223,7 +267,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     QualType ArgTyB = B.get()->getType();
     auto *VTyB = ArgTyB->getAs<VectorType>();
     if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
+      SemaRef.Diag(B.get()->getBeginLoc(),
                    diag::err_typecheck_convert_incompatible)
           << ArgTyB
           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 43335f81ed87f..8ee4618a47958 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -75,6 +75,7 @@ let TargetPrefix = "spv" in {
     [IntrNoMem] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
+  def int_spv_refract : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_reflect : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 28037d7ec5616..8b1864e7c6710 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -865,8 +865,27 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID,
 
   SmallVector<Type *> ArgTys;
   ArgTys.reserve(Args.size());
-  for (auto &I : Args)
+  int i =0;
+  Type * Ity;
+  Type * Nty;
+  Type * etaty;
+
+  for (auto &I : Args) {
+    if(i ==0)
+      Ity = I->getType();
+    if(i ==1)
+      Nty = I->getType();
+    if(i ==2)
+      etaty = I->getType();
     ArgTys.push_back(I->getType());
+    i++;
+  }
+  //assert(Ity == RetTy);
+  //assert(Nty == RetTy);
+  assert(Nty == Ity);
+
+
+
   FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false);
   SmallVector<Type *> OverloadTys;
   Intrinsic::MatchIntrinsicTypesResult Res =
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 40a0bd97adaf9..e07b923f92002 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3094,6 +3094,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectExtInst(ResVReg, ResType, I, CL::fract, GL::Fract);
   case Intrinsic::spv_normalize:
     return selectExtInst(ResVReg, ResType, I, CL::normalize, GL::Normalize);
+  case Intrinsic::spv_refract:
+    return selectExtInst(ResVReg, ResType, I, GL::Refract);
   case Intrinsic::spv_reflect:
     return selectExtInst(ResVReg, ResType, I, GL::Reflect);
   case Intrinsic::spv_rsqrt:

>From aaa8f23f25281dffdb2348329b1d74f5d5b5ce80 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Tue, 22 Apr 2025 16:47:25 -0700
Subject: [PATCH 03/23] Update CodeGen checks

---
 clang/test/CodeGenHLSL/builtins/reflect.hlsl |  32 +-
 clang/test/CodeGenHLSL/builtins/refract.hlsl | 396 +++++++++++++++----
 llvm/include/llvm/IR/IntrinsicsSPIRV.td      |   2 +-
 3 files changed, 338 insertions(+), 92 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/reflect.hlsl b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
index 730e0a831321d..ba7a70b5185d5 100644
--- a/clang/test/CodeGenHLSL/builtins/reflect.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
@@ -6,32 +6,14 @@
 // RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
 // RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
-// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[eta:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[eta]], [[eta]]
-// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N:%.*]], %I
-// CHECK_NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0:%.*]], [[TMP0:%.*]]
-// CHECK_NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1:%.*]]
-// CHECK_NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %mul.i, [[SUB1_I:%.*]]
-// CHECK_NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I:%.*]]
-// CHECK_NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I:%.*]], 0xH0000
-// CHECK_NEXT:    br i1 [[CMP_I:%.*]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
-// CHECK_NEXT:  
-// CHECK_NEXT:    if.else.i:                                        ; preds = %entry
-// CHECK_NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[eta]], %I
-// CHECK_NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N:%.*]], %I
-// CHECK_NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I:%.*]], [[eta]]
-// CHECK_NEXT:    %2 = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I:%.*]])
-// CHECK_NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half %2, [[MUL8_I:%.*]]
-// CHECK_NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I:%.*]], [[N:%.*]]
-// CHECK_NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I:%.*]], [[MUL9_I:%.*]]
-// CHECK_NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
-// CHECK_NEXT:    
-// CHECK_NEXT:    _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK_NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I:%.*]], %if.else.i ], [ 0xH0000, %entry ]
-// CHECK_NEXT:    ret half [[RETVAL_0_I:%.*]]
-  
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[I]], 0xH4000
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[N]]
+// CHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[MUL_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[I]], [[MUL2_I]]
+// CHECK-NEXT:    ret half [[SUB_I]]
 //
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
 // SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index 4dc5b66251b62..baeae1526545f 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -6,87 +6,351 @@
 // RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
 // RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_refract_floatff(
-// CHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[I]], 2.000000e+00
-// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[N]]
-// CHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[MUL_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[I]], [[MUL2_I]]
-// CHECK-NEXT:    ret float [[SUB_I]]
-//
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatff(
-// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// CHECK-NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// CHECK-NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
+// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// CHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK:  if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// CHECK-NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// CHECK-NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
+// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
+// CHECK-NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] 
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
+// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
+// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
+// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// SPVCHECK:  if.else.i:                                        ; preds = %entry
+// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
+// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
+// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// SPVCHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
+// SPVCHECK-NEXT:    ret half [[RETVAL_0_I]]
+//
+half test_refract_half(half I, half N, half ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <2 x half> [[RETVAL_0_I]]
+
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[I]], 2.000000e+00
-// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[N]]
-// SPVCHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[MUL_I]]
-// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[I]], [[MUL2_I]]
-// SPVCHECK-NEXT:    ret float [[SUB_I]]
-//
-float test_refract_float(float I, float N, float eta) {
-    return refract(I, N, eta);
+// SPVCHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
+//
+half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+    return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_(
-// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> [[I]], <2 x float> [[N]])
-// CHECK-NEXT:    [[DOTSCALAR:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], 2.000000e+00
-// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x float> poison, float [[DOTSCALAR]], i64 0
-// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x float> [[TMP0]], <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[TMP1]], [[N]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[I]], [[MUL1_I]]
-// CHECK-NEXT:    ret <2 x float> [[SUB_I]]
-//
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_(
-// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <3 x half> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[SPV_refract_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32(<2 x float> [[I]], <2 x float> [[N]])
-// SPVCHECK-NEXT:    ret <2 x float> [[SPV_refract_I]]
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f64(<3 x half> [[I]], <3 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <3 x half> [[SPV_REFRACT_I]]
 //
-float2 test_refract_float2(float2 I, float2 N, float eta) {
-    return refract(I, N, eta);
+half3 test_refract_half3(half3 I, half3 N, half ETA) {
+    return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_(
-// CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
+// CHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> [[I]], <3 x float> [[N]])
-// CHECK-NEXT:    [[DOTSCALAR:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], 2.000000e+00
-// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <3 x float> poison, float [[DOTSCALAR]], i64 0
-// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x float> [[TMP0]], <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[TMP1]], [[N]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[I]], [[MUL1_I]]
-// CHECK-NEXT:    ret <3 x float> [[SUB_I]]
-//
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_(
-// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> [[N]], <4 x half> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
+// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi4EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <4 x half> poison, half [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <4 x half> [[SPLAT_SPLATINSERT_I]], <4 x half> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <4 x half> poison, half [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <4 x half> [[SPLAT_SPLATINSERT5_I]], <4 x half> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <4 x half> poison, half [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <4 x half> [[SPLAT_SPLATINSERT10_I]], <4 x half> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.sqrt.v4f16(<4 x half> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi4EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi4EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <4 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <4 x half> [[RETVAL_0_I]]
+  
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
+// SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[SPV_refract_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32(<3 x float> [[I]], <3 x float> [[N]])
-// SPVCHECK-NEXT:    ret <3 x float> [[SPV_refract_I]]
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f64(<4 x half> [[I]], <4 x half> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <4 x half> [[SPV_REFRACT_I]]
 //
-float3 test_refract_float3(float3 I, float3 N, float eta) {
-    return refract(I, N, eta);
+half4 test_refract_half4(half4 I, half4 N, half ETA) {
+    return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_(
-// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
+// CHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> [[I]], <4 x float> [[N]])
-// CHECK-NEXT:    [[DOTSCALAR:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], 2.000000e+00
-// CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x float> poison, float [[DOTSCALAR]], i64 0
-// CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[TMP0]], <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[TMP1]], [[N]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[I]], [[MUL1_I]]
-// CHECK-NEXT:    ret <4 x float> [[SUB_I]]
-//
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_(
-// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
+// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[TMP0]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[TMP1]]
+// CHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_4_I]]
+// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB5_I]], 0.000000e+00
+// CHECK:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[I]]
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
+// CHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_7_I]], [[ETA]]
+// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float [[SUB5_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[TMP2]], [[MUL_8_I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], [[N]]
+// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL_6_I]], [[MUL_9_I]]
+// CHECK:    br label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit
+// CHECK: _ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz float [ [[SUB10_I]], %if.else.i ], [ 0.000000e+00, %entry ]
+// CHECK-NEXT:    ret float [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
+// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[SPV_refract_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32(<4 x float> [[I]], <4 x float> [[N]])
-// SPVCHECK-NEXT:    ret <4 x float> [[SPV_refract_I]]
+// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
+// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[TMP0]]
+// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[TMP1]]
+// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_4_I]]
+// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB5_I]], 0.000000e+00
+// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit, label %if.else.i
+
+// SPVCHECK:  if.else.i:                                        ; preds = %entry
+// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
+// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_7_I]], [[ETA]]
+// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float [[SUB5_I]])
+// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[TMP2]], [[MUL_8_I]]
+// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], [[N]]
+// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL_6_I]], [[MUL_9_I]]
+// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit
+// SPVCHECK:  _ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz float [ [[SUB10_I]], %if.else.i ], [ 0.000000e+00, %entry ]
+// SPVCHECK-NEXT:    ret float [[RETVAL_0_I]]
 //
-float4 test_refract_float4(float4 I, float4 N, float eta) {
-    return refract(I, N, eta);
+float test_refract_float(float I, float N, float ETA) {
+    return refract(I, N, ETA);
 }
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
+// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> [[N]], <2 x float> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB4_I]], 0.000000e+00
+// CHECK-NEXT:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIfLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x float> poison, float [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT_I]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <2 x float> poison, float [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT5_I]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x float> poison, float [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT10_I]], <2 x float> poison, <2 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.sqrt.v2f32(<2 x float> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x float> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIfLi2EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:   _ZN4hlsl8__detail16refract_vec_implIfLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x float> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <2 x float> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
+// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float [[ETA]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f64(<2 x float> [[I]], <2 x float> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <2 x float> [[SPV_REFRACT_I]]
+//
+float2 test_refract_float2(float2 I, float2 N, float ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
+// CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> [[N]], <3 x float> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB4_I]], 0.000000e+00
+// CHECK-NEXT:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIfLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x float> poison, float [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT_I]], <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x float> poison, float [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT5_I]], <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x float> poison, float [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT10_I]], <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.sqrt.v3f32(<3 x float> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x float> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIfLi3EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:   _ZN4hlsl8__detail16refract_vec_implIfLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x float> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <3 x float> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
+// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float [[ETA]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f64(<3 x float> [[I]], <3 x float> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <3 x float> [[SPV_REFRACT_I]]
+//
+float3 test_refract_float3(float3 I, float3 N, float ETA) {
+    return refract(I, N, ETA);
+}
+
+// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f
+// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
+// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> [[N]], <4 x float> [[I]])
+// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
+// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
+// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB4_I]], 0.000000e+00
+// CHECK-NEXT:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIfLi4EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
+// CHECK:   if.else.i:                                        ; preds = %entry
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <4 x float> poison, float [[SUB4_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <4 x float> [[SPLAT_SPLATINSERT_I]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <4 x float> poison, float [[ETA]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <4 x float> [[SPLAT_SPLATINSERT5_I]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[SPLAT_SPLAT6_I]], [[I]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[ETA]]
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <4 x float> poison, float [[MUL_9_I]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <4 x float> [[SPLAT_SPLATINSERT10_I]], <4 x float> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.sqrt.v4f32(<4 x float> [[SPLAT_SPLAT_I]])
+// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x float> [[TMP0]], [[SPLAT_SPLAT11_I]]
+// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[ADD_I]], [[N]]
+// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIfLi4EEEDvT0__T_S3_S3_S2_.exit
+// CHECK:   _ZN4hlsl8__detail16refract_vec_implIfLi4EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <4 x float> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
+// CHECK-NEXT:    ret <4 x float> [[RETVAL_0_I]]
+//
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f(
+// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I]], <4 x float> noundef nofpclass(nan inf) [[N]], float noundef nofpclass(nan inf) [[ETA]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-NEXT:  [[ENTRY:.*:]]
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float [[ETA]] to double
+// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> [[I]], <4 x float> [[N]], double [[CONV_I]])
+// SPVCHECK-NEXT:    ret <4 x float> [[SPV_REFRACT_I]]
+//
+float4 test_refract_float4(float4 I, float4 N, float ETA) {
+    return refract(I, N, ETA);
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 8ee4618a47958..2f20d11a49219 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -75,7 +75,7 @@ let TargetPrefix = "spv" in {
     [IntrNoMem] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
-  def int_spv_refract : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
+  def int_spv_refract : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_reflect : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

>From 70646d2fad8977a5000799269d8901515ed01d39 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Mon, 28 Apr 2025 16:39:20 -0700
Subject: [PATCH 04/23] fix tests

---
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 15 ----------
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 18 ------------
 clang/lib/Sema/SemaSPIRV.cpp                  |  1 -
 clang/test/CodeGenSPIRV/Builtins/refract.c    | 21 ++++++++------
 .../test/SemaSPIRV/BuiltIns/refract-errors.c  | 28 +++++++++++++++++++
 llvm/lib/IR/IRBuilder.cpp                     | 21 +-------------
 .../CodeGen/SPIRV/hlsl-intrinsics/refract.ll  | 28 +++++++++++--------
 7 files changed, 57 insertions(+), 75 deletions(-)
 create mode 100644 clang/test/SemaSPIRV/BuiltIns/refract-errors.c

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 62c7f67645e99..08e9e04c492e2 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -92,21 +92,6 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
 #endif
 }
 
-/*
-template <typename T, typename U> constexpr T refract_impl(T I, T N, U eta) {
-  return I - 2 * N * I * N;
-}
-
-template <typename T, int L>
-constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N) {
-#if (__has_builtin(__builtin_spirv_refract))
-  return __builtin_spirv_refract(I, N);
-#else
-  return I - 2 * N * dot(I, N);
-#endif
-}
-*/
-
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index bed2ad1a498b2..13226070a1928 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -512,24 +512,6 @@ const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
   return __detail::refract_impl(I, N, eta);
 }
 
-/*
-template <typename T, typename U>
-_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
-const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
-                                       __detail::is_same<half, T>::value,
-                                   T> refract(T I, T N, U eta) {
-  return __detail::refract_impl(I, N, eta);
-}
-
-template <typename T, typename U>
-_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
-const inline __detail::enable_if_t<__detail::is_arithmetic<U>::Value &&
-                                       __detail::is_same<half, T>::value,
-                                   T> refract(T I, T N, U eta) {
-  return __detail::refract_impl(I, N, eta);
-}
-*/
-
 template <typename T>
 const inline __detail::enable_if_t<
     __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 8fbf30ccfe799..ae99d8251f1dd 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -244,7 +244,6 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     QualType RetTy = ArgTyA;
     TheCall->setType(RetTy);
     assert(RetTy == ArgTyA);
-    //assert(ArgTyB == ArgTyA);
     break;
   }
   case SPIRV::BI__builtin_spirv_reflect: {
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index 06498554bd4d1..5a3beb23a4ed2 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -7,26 +7,29 @@ typedef float float3 __attribute__((ext_vector_type(3)));
 typedef float float4 __attribute__((ext_vector_type(4)));
 
 // CHECK-LABEL: define spir_func <2 x float> @test_refract_float2(
-// CHECK-SAME: <2 x float> noundef [[X:%.*]], <2 x float> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-SAME: <2 x float> noundef [[I:%.*]], <2 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32(<2 x float> [[X]], <2 x float> [[Y]])
+// CHECK-NEXT:    [[CONV:%.*]] = fpext float [[ETA]] to double
+// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32.f64(<2 x float> [[I]], <2 x float> [[N]], double [[CONV]])
 // CHECK-NEXT:    ret <2 x float> [[SPV_REFRACT]]
 //
-float2 test_refract_float2(float2 X, float2 Y, float eta) { return __builtin_spirv_refract(X, Y, eta); }
+float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
 
 // CHECK-LABEL: define spir_func <3 x float> @test_refract_float3(
-// CHECK-SAME: <3 x float> noundef [[X:%.*]], <3 x float> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-SAME: <3 x float> noundef [[I:%.*]], <3 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32(<3 x float> [[X]], <3 x float> [[Y]])
+// CHECK-NEXT:    [[CONV:%.*]] = fpext float [[ETA]] to double
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32.f64(<3 x float> [[I]], <3 x float> [[N]], double [[CONV]])
 // CHECK-NEXT:    ret <3 x float> [[SPV_REFRACT]]
 //
-float3 test_refract_float3(float3 X, float3 Y, float eta) { return __builtin_spirv_refract(X, Y, eta); }
+float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
 
 // CHECK-LABEL: define spir_func <4 x float> @test_refract_float4(
-// CHECK-SAME: <4 x float> noundef [[X:%.*]], <4 x float> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
+// CHECK-SAME: <4 x float> noundef [[I:%.*]], <4 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32(<4 x float> [[X]], <4 x float> [[Y]])
+// CHECK-NEXT:    [[CONV:%.*]] = fpext float [[ETA]] to double
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> [[I]], <4 x float> [[N]], double [[CONV]])
 // CHECK-NEXT:    ret <4 x float> [[SPV_REFRACT]]
 //
-float4 test_refract_float4(float4 X, float4 Y, float eta) { return __builtin_spirv_refract(X, Y, eta); }
+float4 test_refract_float4(float4 I, float4 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
 
diff --git a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
new file mode 100644
index 0000000000000..1baea9dca303e
--- /dev/null
+++ b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 %s -triple spirv-pc-vulkan-compute -verify
+
+typedef float float2 __attribute__((ext_vector_type(2)));
+
+float2 test_no_second_arg(float2 p0) {
+  return __builtin_spirv_refract(p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+}
+
+float2 test_too_few_arg(float2 p0) {
+  return __builtin_spirv_refract(p0, p0);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+}
+
+float2 test_too_many_arg(float2 p0, float p1) {
+  return __builtin_spirv_refract(p0, p0, p1, p1);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+}
+
+float test_double_scalar_inputs(double p0, double p1, double p2) {
+  return __builtin_spirv_refract(p0, p1, p2);
+  //  expected-error@-1 {{passing 'double' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(double)))) double' (vector of 2 'double' values)}}
+}
+
+float test_int_scalar_inputs(int p0, int p1, int p2) {
+  return __builtin_spirv_refract(p0, p1, p2);
+  //  expected-error@-1 {{passing 'int' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(int)))) int' (vector of 2 'int' values)}}
+}
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 8b1864e7c6710..28037d7ec5616 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -865,27 +865,8 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID,
 
   SmallVector<Type *> ArgTys;
   ArgTys.reserve(Args.size());
-  int i =0;
-  Type * Ity;
-  Type * Nty;
-  Type * etaty;
-
-  for (auto &I : Args) {
-    if(i ==0)
-      Ity = I->getType();
-    if(i ==1)
-      Nty = I->getType();
-    if(i ==2)
-      etaty = I->getType();
+  for (auto &I : Args)
     ArgTys.push_back(I->getType());
-    i++;
-  }
-  //assert(Ity == RetTy);
-  //assert(Nty == RetTy);
-  assert(Nty == Ity);
-
-
-
   FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false);
   SmallVector<Type *> OverloadTys;
   Intrinsic::MatchIntrinsicTypesResult Res =
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
index 7dc6d24f651de..48b40dd8ba15a 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -7,27 +7,31 @@
 ; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
 ; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
 ; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
 ; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
 
-define noundef <4 x half> @refract_half4(<4 x half> noundef %a, <4 x half> noundef %b, half %eta) {
+define noundef  <4 x half> @refract_half(<4 x half> noundef  %I, <4 x half> noundef  %N, half noundef  %ETA) {
 entry:
   ; CHECK: %[[#]] = OpFunction %[[#vec4_float_16]] None %[[#]]
   ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
   ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
-  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] refract %[[#arg0]] %[[#arg1]]
-  %spv.refract = call <4 x half> @llvm.spv.refract.f16(<4 x half> %a, <4 x half> %b)
-  ret <4 x half> %spv.refract
+  ; CHECK: %[[#arg2_float_16:]] = OpFunctionParameter %[[#float_16:]]
+  ; CHECK: %[[#arg2:]] = OpFConvert %[[#float_64:]] %[[#arg2_float_16:]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
+  %conv.i = fpext reassoc nnan ninf nsz arcp afn half %ETA to double
+  %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f64(<4 x half> %I, <4 x half> %N, double %conv.i)
+  ret <4 x half> %spv.refract.i
 }
 
-define noundef <4 x float> @refract_float4(<4 x float> noundef %a, <4 x float> noundef %b, float %eta) {
+define noundef  <4 x float> @refract_float4(<4 x float> noundef  %I, <4 x float> noundef  %N, float noundef  %ETA) {
 entry:
+  %conv.i = fpext reassoc nnan ninf nsz arcp afn float %ETA to double
   ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
   ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
   ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
-  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] refract %[[#arg0]] %[[#arg1]]
-  %spv.refract = call <4 x float> @llvm.spv.refract.f32(<4 x float> %a, <4 x float> %b)
-  ret <4 x float> %spv.refract
-}
-
-declare <4 x half> @llvm.spv.refract.f16(<4 x half>, <4 x half>, half)
-declare <4 x float> @llvm.spv.refract.f32(<4 x float>, <4 x float>, float)
+  ; CHECK: %[[#arg2_float_32:]] = OpFunctionParameter %[[#float_32:]]
+  ; CHECK: %[[#arg2:]] = OpFConvert %[[#float_64:]] %[[#arg2_float_32:]]
+  ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
+  %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> %I, <4 x float> %N, double %conv.i)
+  ret <4 x float> %spv.refract.i
+}
\ No newline at end of file

>From 9b0ccaf73429dd3186fddf8b24708bd4e2dcc39a Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Mon, 28 Apr 2025 17:15:02 -0700
Subject: [PATCH 05/23] adding new lines to EOF

---
 clang/test/CodeGenHLSL/builtins/refract.hlsl       | 2 +-
 clang/test/CodeGenSPIRV/Builtins/refract.c         | 1 -
 clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl   | 2 +-
 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll | 2 +-
 4 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index baeae1526545f..a2e160f17b582 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -353,4 +353,4 @@ float3 test_refract_float3(float3 I, float3 N, float ETA) {
 //
 float4 test_refract_float4(float4 I, float4 N, float ETA) {
     return refract(I, N, ETA);
-}
\ No newline at end of file
+}
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index 5a3beb23a4ed2..82f620fa293de 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -32,4 +32,3 @@ float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spi
 // CHECK-NEXT:    ret <4 x float> [[SPV_REFRACT]]
 //
 float4 test_refract_float4(float4 I, float4 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
-
diff --git a/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
index eee913e6bb6e7..fe094591bea6a 100644
--- a/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
@@ -71,4 +71,4 @@ float5 test_vec5_inputs(float5 p0, float5 p1,  float p2) {
  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: deduced conflicting types for parameter 'T' ('float5' (vector of 5 'float' values) vs. 'float')}}
  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 5]: no type named 'Type' in 'hlsl::__detail::enable_if<false, half>'}}
  // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 5]: no type named 'Type' in 'hlsl::__detail::enable_if<false, float>'}}
-}
\ No newline at end of file
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
index 48b40dd8ba15a..780f7bfe60db1 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -34,4 +34,4 @@ entry:
   ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
   %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> %I, <4 x float> %N, double %conv.i)
   ret <4 x float> %spv.refract.i
-}
\ No newline at end of file
+}

>From b46a32f163668f07c05e31f156669585756b3ff4 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Thu, 1 May 2025 17:26:32 -0700
Subject: [PATCH 06/23] Addressing partial review comments

---
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 20 +++++++++----------
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  6 +++---
 clang/lib/Sema/SemaSPIRV.cpp                  |  6 ++----
 clang/test/CodeGenSPIRV/Builtins/refract.c    |  2 +-
 .../test/SemaSPIRV/BuiltIns/refract-errors.c  |  2 +-
 5 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 08e9e04c492e2..6b6295a3fe32e 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -10,7 +10,7 @@
 #define _HLSL_HLSL_INTRINSIC_HELPERS_H_
 
 namespace hlsl {
-namespace __detail {
+namespace __dETAil {
 
 constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
   // Use the same scaling factor used by FXC, and DXC for DXIL
@@ -71,24 +71,24 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
-template <typename T> constexpr T refract_impl(T I, T N, T eta) {
-  T k = 1 - eta * eta * (1 - (N * I * N *I));
-  if(k < 0)
+template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
+  T K = 1 - Eta * Eta * (1 - (N * I * N * I));
+  if (K < 0)
     return 0;
   else
-    return (eta * I - (eta * N * I + sqrt(k)) * N);
+    return (Eta * I - (Eta * N * I + sqrt(K)) * N);
 }
 
 template <typename T, int L>
-constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T eta) {
+constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
 #if (__has_builtin(__builtin_spirv_refract))
-  return __builtin_spirv_refract(I, N, eta);
+  return __builtin_spirv_refract(I, N, Eta);
 #else
-  vector<T, L> k = 1 - eta * eta * (1 - dot(N, I) * dot(N, I));
-  if(k < 0)
+  vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
+  if (K < 0)
     return 0;
   else
-    return (eta * I - (eta * dot(N, I) + sqrt(k)) * N);
+    return (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
 #endif
 }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 13226070a1928..8c262ffce25f1 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -487,8 +487,8 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
 /// \param eta The refraction index.
 ///
 /// The return value is a floating-point vector that represents the refraction
-/// using the refraction index, \a eta, for the direction of the entering ray, \a I,
-/// off a surface with the normal \a N.
+/// using the refraction index, \a eta, for the direction of the entering ray,
+/// \a I, off a surface with the normal \a N.
 ///
 /// This function calculates the refraction vector using the following formulas:
 /// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
@@ -515,7 +515,7 @@ const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
 template <typename T>
 const inline __detail::enable_if_t<
     __detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
-    refract(T I, T N, T eta) {
+refract(T I, T N, T eta) {
   return __detail::refract_impl(I, N, eta);
 }
 
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index ae99d8251f1dd..278ee8acd56d0 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -173,7 +173,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     QualType ArgTyB = B.get()->getType();
     auto *VTyB = ArgTyB->getAs<VectorType>();
     if (VTyB == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
+      SemaRef.Diag(B.get()->getBeginLoc(),
                    diag::err_typecheck_convert_incompatible)
           << ArgTyB
           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
@@ -234,8 +234,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     ExprResult C = TheCall->getArg(2);
     QualType ArgTyC = C.get()->getType();
     if (!ArgTyC->hasFloatingRepresentation()) {
-      SemaRef.Diag(C.get()->getBeginLoc(),
-                   diag::err_builtin_invalid_arg_type)
+      SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
           << ArgTyC;
       return true;
@@ -243,7 +242,6 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
 
     QualType RetTy = ArgTyA;
     TheCall->setType(RetTy);
-    assert(RetTy == ArgTyA);
     break;
   }
   case SPIRV::BI__builtin_spirv_reflect: {
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index 82f620fa293de..a867b20ae123a 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -1,6 +1,6 @@
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 
-// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
 
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float3 __attribute__((ext_vector_type(3)));
diff --git a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
index 1baea9dca303e..9f6589bb05b6a 100644
--- a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
+++ b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
@@ -7,7 +7,7 @@ float2 test_no_second_arg(float2 p0) {
   // expected-error at -1 {{too few arguments to function call, expected 3, have 1}}
 }
 
-float2 test_too_few_arg(float2 p0) {
+float2 test_no_third_arg(float2 p0) {
   return __builtin_spirv_refract(p0, p0);
   // expected-error at -1 {{too few arguments to function call, expected 3, have 2}}
 }

>From 390a4c7058d36fe8f04aa373b86fee6c74de6f20 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 14 May 2025 12:41:21 -0700
Subject: [PATCH 07/23] modularize SemaSPIRV.cpp

---
 clang/lib/CodeGen/TargetBuiltins/SPIR.cpp     |   2 +-
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h |  15 +--
 clang/lib/Sema/SemaSPIRV.cpp                  | 124 ++++++------------
 clang/test/CodeGenHLSL/builtins/reflect.hlsl  |   1 -
 clang/test/CodeGenHLSL/builtins/refract.hlsl  |  17 +--
 clang/test/CodeGenSPIRV/Builtins/refract.c    |   2 +-
 .../CodeGen/SPIRV/opencl/refract-error.ll     |  12 ++
 7 files changed, 69 insertions(+), 104 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/opencl/refract-error.ll

diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 026444472a1ac..1c63e04f757c7 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -64,7 +64,7 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
     Value *eta = EmitScalarExpr(E->getArg(2));
     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
            E->getArg(1)->getType()->hasFloatingRepresentation() &&
-           E->getArg(2)->getType()->hasFloatingRepresentation() &&
+           E->getArg(2)->getType()->isFloatingType() &&
            "refract operands must have a float representation");
     assert(E->getArg(0)->getType()->isVectorType() &&
            E->getArg(1)->getType()->isVectorType() &&
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 6b6295a3fe32e..5d7cd54a41422 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -10,7 +10,7 @@
 #define _HLSL_HLSL_INTRINSIC_HELPERS_H_
 
 namespace hlsl {
-namespace __dETAil {
+namespace __detail {
 
 constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
   // Use the same scaling factor used by FXC, and DXC for DXIL
@@ -73,10 +73,8 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 
 template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
   T K = 1 - Eta * Eta * (1 - (N * I * N * I));
-  if (K < 0)
-    return 0;
-  else
-    return (Eta * I - (Eta * N * I + sqrt(K)) * N);
+  T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
 }
 
 template <typename T, int L>
@@ -85,13 +83,12 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
   return __builtin_spirv_refract(I, N, Eta);
 #else
   vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
-  if (K < 0)
-    return 0;
-  else
-    return (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
+  vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
+  return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
 #endif
 }
 
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 278ee8acd56d0..d8833161e0673 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -29,6 +29,31 @@ namespace clang {
 
 SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
 
+/// Checks if the first `NumArgsToCheck` arguments of a function call are of vector type.
+/// If any of the arguments is not a vector type, it emits a diagnostic error and returns `true`.
+/// Otherwise, it returns `false`.
+///
+/// \param TheCall The function call expression to check.
+/// \param NumArgsToCheck The number of arguments to check for vector type.
+/// \return `true` if any of the arguments is not a vector type, `false` otherwise.
+
+bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+  for (unsigned i = 0; i < NumArgsToCheck; ++i) {
+    ExprResult Arg = TheCall->getArg(i);
+    QualType ArgTy = Arg.get()->getType();
+    auto *VTy = ArgTy->getAs<VectorType>();
+    if (VTy == nullptr) {
+      SemaRef.Diag(Arg.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTy
+          << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+  }
+  return false;
+}
+
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -45,6 +70,7 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   }
   return false;
 }
+bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
 
 static std::optional<int>
 processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
@@ -157,49 +183,23 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(B.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check both arguments
+    if (CheckVectorArgs(TheCall, 2))
       return true;
-    }
 
-    QualType RetTy = VTyA->getElementType();
+    QualType RetTy = TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
   case SPIRV::BI__builtin_spirv_length: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTy = ArgTyA->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+
+    // Use the helper function to check the argument
+    if (CheckVectorArgs(TheCall, 1))
       return true;
-    }
-    QualType RetTy = VTy->getElementType();
+
+    QualType RetTy = TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
@@ -207,40 +207,20 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check the first two arguments
+    if (CheckVectorArgs(TheCall, 2))
       return true;
-    }
-
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(B.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
 
     ExprResult C = TheCall->getArg(2);
     QualType ArgTyC = C.get()->getType();
-    if (!ArgTyC->hasFloatingRepresentation()) {
+    if (!ArgTyC->isFloatingType()) {
       SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1
           << ArgTyC;
       return true;
     }
 
-    QualType RetTy = ArgTyA;
+    QualType RetTy = TheCall->getArg(0)->getType();
     TheCall->setType(RetTy);
     break;
   }
@@ -248,31 +228,11 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(B.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check both arguments
+    if (CheckVectorArgs(TheCall, 2))
       return true;
-    }
 
-    QualType RetTy = ArgTyA;
+    QualType RetTy = TheCall->getArg(0)->getType();
     TheCall->setType(RetTy);
     break;
   }
diff --git a/clang/test/CodeGenHLSL/builtins/reflect.hlsl b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
index ba7a70b5185d5..351ffac1c1283 100644
--- a/clang/test/CodeGenHLSL/builtins/reflect.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
 // RUN:   -emit-llvm -O1 -o - | FileCheck %s
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index a2e160f17b582..e396e372c9aa1 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -1,4 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
 // RUN:   -emit-llvm -O1 -o - | FileCheck %s
@@ -59,11 +58,10 @@ half test_refract_half(half I, half N, half ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
-// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[ETA]]
 // CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
 // CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
 // CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
@@ -76,7 +74,7 @@ half test_refract_half(half I, half N, half ETA) {
 // CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
 // CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
 // CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
+// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
 // CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
 // CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
 // CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
@@ -90,14 +88,13 @@ half test_refract_half(half I, half N, half ETA) {
 
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
-// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
+// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
 // SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
 // SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
 //
-half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
+half2 test_refract_half2(half2 I, half2 N, half ETA) {
     return refract(I, N, ETA);
 }
 
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index a867b20ae123a..82f620fa293de 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -1,6 +1,6 @@
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 
-// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
 
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float3 __attribute__((ext_vector_type(3)));
diff --git a/llvm/test/CodeGen/SPIRV/opencl/refract-error.ll b/llvm/test/CodeGen/SPIRV/opencl/refract-error.ll
new file mode 100644
index 0000000000000..28208fb2e72f8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/refract-error.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
+; RUN: not llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
+
+; CHECK: LLVM ERROR: %{{.*}} = G_INTRINSIC intrinsic(@llvm.spv.refract), %{{.*}}, %{{.*}}, %{{.*}} is only supported with the GLSL extended instruction set.
+
+define noundef <4 x float> @refract_float4(<4 x float> noundef %I, <4 x float> noundef %N, float noundef %ETA) {
+entry:
+  %spv.refract = call <4 x float> @llvm.spv.refract.f32(<4 x float> %I, <4 x float> %N, float %ETA)
+  ret <4 x float> %spv.refract
+}
+
+declare <4 x float> @llvm.spv.refract.f32(<4 x float>, <4 x float>, float)

>From 478ef948c7834bb5b19bf353cd518c185a05e81f Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 14 May 2025 13:52:08 -0700
Subject: [PATCH 08/23]  clang-format

---
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h  |  1 -
 clang/lib/Sema/SemaSPIRV.cpp                   | 18 ++++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 5d7cd54a41422..864bc81fd3e63 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -88,7 +88,6 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
 #endif
 }
 
-
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index d8833161e0673..4b87ad737f270 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -29,13 +29,14 @@ namespace clang {
 
 SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
 
-/// Checks if the first `NumArgsToCheck` arguments of a function call are of vector type.
-/// If any of the arguments is not a vector type, it emits a diagnostic error and returns `true`.
-/// Otherwise, it returns `false`.
+/// Checks if the first `NumArgsToCheck` arguments of a function call are of
+/// vector type. If any of the arguments is not a vector type, it emits a
+/// diagnostic error and returns `true`. Otherwise, it returns `false`.
 ///
 /// \param TheCall The function call expression to check.
 /// \param NumArgsToCheck The number of arguments to check for vector type.
-/// \return `true` if any of the arguments is not a vector type, `false` otherwise.
+/// \return `true` if any of the arguments is not a vector type, `false`
+/// otherwise.
 
 bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
   for (unsigned i = 0; i < NumArgsToCheck; ++i) {
@@ -187,7 +188,8 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (CheckVectorArgs(TheCall, 2))
       return true;
 
-    QualType RetTy = TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
@@ -199,7 +201,8 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (CheckVectorArgs(TheCall, 1))
       return true;
 
-    QualType RetTy = TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
+    QualType RetTy =
+        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
     TheCall->setType(RetTy);
     break;
   }
@@ -215,8 +218,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     QualType ArgTyC = C.get()->getType();
     if (!ArgTyC->isFloatingType()) {
       SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1
-          << ArgTyC;
+          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
       return true;
     }
 

>From 76b6c7f46d7d92f3908aba0c2c7f3a372d342184 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 21 May 2025 17:23:33 -0700
Subject: [PATCH 09/23] Remove O1 flag from refract test

---
 clang/test/CodeGenHLSL/builtins/refract.hlsl | 507 ++++++++-----------
 1 file changed, 222 insertions(+), 285 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index e396e372c9aa1..22bda78c82b4a 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -1,352 +1,289 @@
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
-// RUN:   -emit-llvm -O1 -o - | FileCheck %s
+// RUN:   -emit-llvm -o - | FileCheck %s
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
-// RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
+// RUN:   -emit-llvm -o - | FileCheck %s --check-prefix=SPVCHECK
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
-// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
-// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
-// CHECK-NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
-// CHECK-NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB1_I]]
-// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
-// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
-// CHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
-// CHECK:  if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
-// CHECK-NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
-// CHECK-NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], [[ETA]]
-// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL8_I]]
-// CHECK-NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
-// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
-// CHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
-// CHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
-// CHECK-NEXT:    ret half [[RETVAL_0_I]]
-//
+// CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK:  [[ENTRY:.*:]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL2_I]], %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], %{{.*}}
+// CHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL8_I]], [[TMP2]]
+// CHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half %{{.*}}, 0xH0000
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], half 0xH0000, half %{{.*}}
+// CHECK:    ret half [[HLSL_SELECT_I]]
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
-// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] 
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
-// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
-// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[TMP0]]
-// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1]]
-// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_4_I]]
-// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I]], 0xH0000
-// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
-// SPVCHECK:  if.else.i:                                        ; preds = %entry
-// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[I]]
-// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[I]]
-// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_7_I]], [[ETA]]
-// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I]])
-// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[TMP2]], [[MUL_8_I]]
-// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], [[N]]
-// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL_6_I]], [[MUL_9_I]]
-// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
-// SPVCHECK:  _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
-// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I]], %if.else.i ], [ 0xH0000, %entry ]
-// SPVCHECK-NEXT:    ret half [[RETVAL_0_I]]
+// SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] 
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], %{{.*}}
+// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL2_I]], %{{.*}}
+// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
+// SPVCHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// SPVCHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// SPVCHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], %{{.*}}
+// SPVCHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
+// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL8_I]], [[TMP2]]
+// SPVCHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
+// SPVCHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// SPVCHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half %{{.*}}, 0xH0000
+// SPVCHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], half 0xH0000, half %{{.*}}
+// SPVCHECK:    ret half [[HLSL_SELECT_I]]
 //
 half test_refract_half(half I, half N, half ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
-// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[ETA]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
-// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
-// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
-// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
-// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x half> poison, half [[SUB4_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
-// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], [[N]]
-// CHECK-NEXT:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
-// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit
-// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x half> [ %sub13.i, %if.else.i ], [ zeroinitializer, %entry ]
-// CHECK-NEXT:    ret <2 x half> [[RETVAL_0_I]]
-
+// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK:  [[ENTRY:.*:]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
+// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, [[HLSL_DOT8_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, [[TMP0]]
+// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt <2 x half> %{{.*}}, zeroinitializer
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn <2 x i1> %{{.*}}, <2 x half> zeroinitializer, <2 x half> %{{.*}}
+// CHECK:    ret <2 x half> [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
-// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
-// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
-// SPVCHECK-NEXT:    ret <2 x half> [[SPV_REFRACT_I]]
+// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
+// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> %{{.*}}, <2 x half> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    ret <2 x half> [[SPV_REFRACT_I]]
 //
 half2 test_refract_half2(half2 I, half2 N, half ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
-// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> [[N]], <3 x half> [[I]])
-// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
-// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
-// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
-// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x half> poison, half [[SUB4_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT_I]], <3 x half> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x half> poison, half [[ETA]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT5_I]], <3 x half> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x half> poison, half [[MUL_9_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x half> [[SPLAT_SPLATINSERT10_I]], <3 x half> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> [[SPLAT_SPLAT_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
-// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], [[N]]
-// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
-// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit
-// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
-// CHECK-NEXT:    ret <3 x half> [[RETVAL_0_I]]
+// CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK:  [[ENTRY:.*:]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
+// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, [[HLSL_DOT8_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, [[TMP0]]
+// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt <3 x half> %{{.*}}, zeroinitializer
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn <3 x i1> %{{.*}}, <3 x half> zeroinitializer, <3 x half> %{{.*}}
+// CHECK:    ret <3 x half> [[HLSL_SELECT_I]]
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
-// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
-// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f64(<3 x half> [[I]], <3 x half> [[N]], double [[CONV_I]])
-// SPVCHECK-NEXT:    ret <3 x half> [[SPV_REFRACT_I]]
+// SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f64(<3 x half> %{{.*}}, <3 x half> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    ret <3 x half> [[SPV_REFRACT_I]]
 //
 half3 test_refract_half3(half3 I, half3 N, half ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
-// CHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> [[N]], <4 x half> [[I]])
-// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
-// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
-// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB4_I]], 0xH0000
-// CHECK:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIDhLi4EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <4 x half> poison, half [[SUB4_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <4 x half> [[SPLAT_SPLATINSERT_I]], <4 x half> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <4 x half> poison, half [[ETA]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <4 x half> [[SPLAT_SPLATINSERT5_I]], <4 x half> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <4 x half> poison, half [[MUL_9_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <4 x half> [[SPLAT_SPLATINSERT10_I]], <4 x half> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.sqrt.v4f16(<4 x half> [[SPLAT_SPLAT_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x half> [[TMP0]], [[SPLAT_SPLAT11_I]]
-// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[ADD_I]], [[N]]
-// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[MUL_7_I]], [[MUL_12_I]]
-// CHECK:    br label %_ZN4hlsl8__detail16refract_vec_implIDhLi4EEEDvT0__T_S3_S3_S2_.exit
-// CHECK:  _ZN4hlsl8__detail16refract_vec_implIDhLi4EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <4 x half> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
-// CHECK-NEXT:    ret <4 x half> [[RETVAL_0_I]]
-  
+// CHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK:  [[ENTRY:.*:]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
+// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
+// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
+// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, [[HLSL_DOT8_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.sqrt.v4f16(<4 x half> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, [[TMP0]]
+// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt <4 x half> %{{.*}}, zeroinitializer
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn <4 x i1> %{{.*}}, <4 x half> zeroinitializer, <4 x half> %{{.*}}
+// CHECK:    ret <4 x half> [[HLSL_SELECT_I]]
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
-// SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
-// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f64(<4 x half> [[I]], <4 x half> [[N]], double [[CONV_I]])
-// SPVCHECK-NEXT:    ret <4 x half> [[SPV_REFRACT_I]]
+// SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f64(<4 x half> %{{.*}}, <4 x half> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    ret <4 x half> [[SPV_REFRACT_I]]
 //
 half4 test_refract_half4(half4 I, half4 N, half ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
-// CHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
-// CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[TMP0]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[TMP1]]
-// CHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_4_I]]
-// CHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB5_I]], 0.000000e+00
-// CHECK:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[I]]
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
-// CHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_7_I]], [[ETA]]
-// CHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float [[SUB5_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[TMP2]], [[MUL_8_I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], [[N]]
-// CHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL_6_I]], [[MUL_9_I]]
-// CHECK:    br label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit
-// CHECK: _ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz float [ [[SUB10_I]], %if.else.i ], [ 0.000000e+00, %entry ]
-// CHECK-NEXT:    ret float [[RETVAL_0_I]]
+// CHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK:  [[ENTRY:.*:]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL2_I]], %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL4_I]]
+// CHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL7_I]], %{{.*}}
+// CHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL8_I]], [[TMP2]]
+// CHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL6_I]], [[MUL9_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float %{{.*}}, 0.000000e+00
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], float 0.000000e+00, float %{{.*}}
+// CHECK:    ret float [[HLSL_SELECT_I]]
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
-// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
-// SPVCHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
-// SPVCHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[TMP0]], [[TMP0]]
-// SPVCHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[TMP1]]
-// SPVCHECK-NEXT:    [[MUL_4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// SPVCHECK-NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_4_I]]
-// SPVCHECK-NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB5_I]], 0.000000e+00
-// SPVCHECK-NEXT:    br i1 [[CMP_I]], label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit, label %if.else.i
-
-// SPVCHECK:  if.else.i:                                        ; preds = %entry
-// SPVCHECK-NEXT:    [[MUL_6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[I]]
-// SPVCHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[N]], [[I]]
-// SPVCHECK-NEXT:    [[MUL_8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_7_I]], [[ETA]]
-// SPVCHECK-NEXT:    [[TMP2:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float [[SUB5_I]])
-// SPVCHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[TMP2]], [[MUL_8_I]]
-// SPVCHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], [[N]]
-// SPVCHECK-NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL_6_I]], [[MUL_9_I]]
-// SPVCHECK-NEXT:    br label %_ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit
-// SPVCHECK:  _ZN4hlsl8__detail12refract_implIfEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
-// SPVCHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz float [ [[SUB10_I]], %if.else.i ], [ 0.000000e+00, %entry ]
-// SPVCHECK-NEXT:    ret float [[RETVAL_0_I]]
+// SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], %{{.*}}
+// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL2_I]], %{{.*}}
+// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
+// SPVCHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// SPVCHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL4_I]]
+// SPVCHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// SPVCHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL7_I]], %{{.*}}
+// SPVCHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
+// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL8_I]], [[TMP2]]
+// SPVCHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
+// SPVCHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL6_I]], [[MUL9_I]]
+// SPVCHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float %{{.*}}, 0.000000e+00
+// SPVCHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], float 0.000000e+00, float %{{.*}}
+// SPVCHECK:    ret float [[HLSL_SELECT_I]]
 //
 float test_refract_float(float I, float N, float ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
-// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> [[N]], <2 x float> [[I]])
-// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
-// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
-// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB4_I]], 0.000000e+00
-// CHECK-NEXT:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIfLi2EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <2 x float> poison, float [[SUB4_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT_I]], <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <2 x float> poison, float [[ETA]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT5_I]], <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[ETA]]
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x float> poison, float [[MUL_9_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x float> [[SPLAT_SPLATINSERT10_I]], <2 x float> poison, <2 x i32> zeroinitializer
-// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.sqrt.v2f32(<2 x float> [[SPLAT_SPLAT_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x float> [[TMP0]], [[SPLAT_SPLAT11_I]]
-// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[ADD_I]], [[N]]
-// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[MUL_7_I]], [[MUL_12_I]]
-// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIfLi2EEEDvT0__T_S3_S3_S2_.exit
-// CHECK:   _ZN4hlsl8__detail16refract_vec_implIfLi2EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <2 x float> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
-// CHECK-NEXT:    ret <2 x float> [[RETVAL_0_I]]
+// CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK:  [[ENTRY:.*:]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
+// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
+// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
+// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
+// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
+// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, [[HLSL_DOT8_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, [[TMP0]]
+// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <2 x float> %{{.*}}, zeroinitializer
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <2 x float> zeroinitializer, <2 x float> %{{.*}}
+// CHECK:    ret <2 x float> [[HLSL_SELECT_I]]
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
-// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float [[ETA]] to double
-// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f64(<2 x float> [[I]], <2 x float> [[N]], double [[CONV_I]])
-// SPVCHECK-NEXT:    ret <2 x float> [[SPV_REFRACT_I]]
+// SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float %{{.*}} to double
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f64(<2 x float> %{{.*}}, <2 x float> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    ret <2 x float> [[SPV_REFRACT_I]]
 //
 float2 test_refract_float2(float2 I, float2 N, float ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
-// CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> [[N]], <3 x float> [[I]])
-// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
-// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
-// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB4_I]], 0.000000e+00
-// CHECK-NEXT:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIfLi3EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <3 x float> poison, float [[SUB4_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT_I]], <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <3 x float> poison, float [[ETA]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT5_I]], <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[ETA]]
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <3 x float> poison, float [[MUL_9_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <3 x float> [[SPLAT_SPLATINSERT10_I]], <3 x float> poison, <3 x i32> zeroinitializer
-// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.sqrt.v3f32(<3 x float> [[SPLAT_SPLAT_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x float> [[TMP0]], [[SPLAT_SPLAT11_I]]
-// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[ADD_I]], [[N]]
-// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[MUL_7_I]], [[MUL_12_I]]
-// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIfLi3EEEDvT0__T_S3_S3_S2_.exit
-// CHECK:   _ZN4hlsl8__detail16refract_vec_implIfLi3EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <3 x float> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
-// CHECK-NEXT:    ret <3 x float> [[RETVAL_0_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
+// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
+// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
+// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
+// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
+// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, [[HLSL_DOT8_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.sqrt.v3f32(<3 x float> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, [[TMP0]]
+// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <3 x float> %{{.*}}, zeroinitializer
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <3 x float> zeroinitializer, <3 x float> %{{.*}}
+// CHECK:    ret <3 x float> [[HLSL_SELECT_I]]
 //
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
-// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float [[ETA]] to double
-// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f64(<3 x float> [[I]], <3 x float> [[N]], double [[CONV_I]])
-// SPVCHECK-NEXT:    ret <3 x float> [[SPV_REFRACT_I]]
+// SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float %{{.*}} to double
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f64(<3 x float> %{{.*}}, <3 x float> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    ret <3 x float> [[SPV_REFRACT_I]]
 //
 float3 test_refract_float3(float3 I, float3 N, float ETA) {
     return refract(I, N, ETA);
 }
 
 // CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f
-// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ETA]], [[ETA]]
-// CHECK-NEXT:    [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> [[N]], <4 x float> [[I]])
-// CHECK-NEXT:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
-// CHECK-NEXT:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK-NEXT:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
-// CHECK-NEXT:    [[CAST_VTRUNC_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float [[SUB4_I]], 0.000000e+00
-// CHECK-NEXT:    br i1 [[CAST_VTRUNC_I]], label %_ZN4hlsl8__detail16refract_vec_implIfLi4EEEDvT0__T_S3_S3_S2_.exit, label %if.else.i
-// CHECK:   if.else.i:                                        ; preds = %entry
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT_I:%.*]] = insertelement <4 x float> poison, float [[SUB4_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT_I:%.*]] = shufflevector <4 x float> [[SPLAT_SPLATINSERT_I]], <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT5_I:%.*]] = insertelement <4 x float> poison, float [[ETA]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT6_I:%.*]] = shufflevector <4 x float> [[SPLAT_SPLATINSERT5_I]], <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[SPLAT_SPLAT6_I]], [[I]]
-// CHECK-NEXT:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[ETA]]
-// CHECK-NEXT:    [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <4 x float> poison, float [[MUL_9_I]], i64 0
-// CHECK-NEXT:    [[SPLAT_SPLAT11_I:%.*]] = shufflevector <4 x float> [[SPLAT_SPLATINSERT10_I]], <4 x float> poison, <4 x i32> zeroinitializer
-// CHECK-NEXT:    [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.sqrt.v4f32(<4 x float> [[SPLAT_SPLAT_I]])
-// CHECK-NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x float> [[TMP0]], [[SPLAT_SPLAT11_I]]
-// CHECK-NEXT:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[ADD_I]], [[N]]
-// CHECK-NEXT:   [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[MUL_7_I]], [[MUL_12_I]]
-// CHECK-NEXT:    br label %_ZN4hlsl8__detail16refract_vec_implIfLi4EEEDvT0__T_S3_S3_S2_.exit
-// CHECK:   _ZN4hlsl8__detail16refract_vec_implIfLi4EEEDvT0__T_S3_S3_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi nsz <4 x float> [ [[SUB13_I]], %if.else.i ], [ zeroinitializer, %entry ]
-// CHECK-NEXT:    ret <4 x float> [[RETVAL_0_I]]
-//
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
+// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
+// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
+// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
+// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
+// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, [[HLSL_DOT8_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.sqrt.v4f32(<4 x float> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, [[TMP0]]
+// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <4 x float> %{{.*}}, zeroinitializer
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <4 x float> zeroinitializer, <4 x float> %{{.*}}
+// CHECK:    ret <4 x float> [[HLSL_SELECT_I]]
+
 // SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f(
-// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I]], <4 x float> noundef nofpclass(nan inf) [[N]], float noundef nofpclass(nan inf) [[ETA]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
-// SPVCHECK-NEXT:  [[ENTRY:.*:]]
-// SPVCHECK-NEXT:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float [[ETA]] to double
-// SPVCHECK-NEXT:    [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> [[I]], <4 x float> [[N]], double [[CONV_I]])
-// SPVCHECK-NEXT:    ret <4 x float> [[SPV_REFRACT_I]]
+// SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) %{{.*}}, <4 x float> noundef nofpclass(nan inf) %{{.*}}, float noundef nofpclass(nan inf) %{{.*}}) #[[ATTR0:[0-9]+]] {
+// SPVCHECK:  [[ENTRY:.*:]]
+// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float %{{.*}} to double
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> %{{.*}}, <4 x float> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    ret <4 x float> [[SPV_REFRACT_I]]
 //
 float4 test_refract_float4(float4 I, float4 N, float ETA) {
     return refract(I, N, ETA);

>From dbb1dda581f81e97b67dd76582d315f45db413b1 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 16 Apr 2025 13:48:06 -0700
Subject: [PATCH 10/23] Adding dxil tests

---
 clang/lib/Sema/SemaSPIRV.cpp                 | 44 ++++++++++++++++++++
 clang/test/CodeGenHLSL/builtins/reflect.hlsl | 28 ++++++++++---
 llvm/lib/IR/IRBuilder.cpp                    | 21 +++++++++-
 3 files changed, 87 insertions(+), 6 deletions(-)

diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 4b87ad737f270..1e086c0738863 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -226,6 +226,50 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     TheCall->setType(RetTy);
     break;
   }
+  case SPIRV::BI__builtin_spirv_refract: {
+    if (SemaRef.checkArgCount(TheCall, 3))
+      return true;
+
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult C = TheCall->getArg(2);
+    QualType ArgTyC = C.get()->getType();
+    if (!ArgTyC->hasFloatingRepresentation()) {
+      SemaRef.Diag(C.get()->getBeginLoc(),
+                   diag::err_builtin_invalid_arg_type)
+          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << ArgTyC;
+      return true;
+    }
+
+    QualType RetTy = ArgTyA;
+    TheCall->setType(RetTy);
+    assert(RetTy == ArgTyA);
+    //assert(ArgTyB == ArgTyA);
+    break;
+  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/reflect.hlsl b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
index 351ffac1c1283..dc5011eb4b1a7 100644
--- a/clang/test/CodeGenHLSL/builtins/reflect.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
@@ -8,11 +8,29 @@
 // CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
 // CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[I]], 0xH4000
-// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[N]]
-// CHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[MUL_I]]
-// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[I]], [[MUL2_I]]
-// CHECK-NEXT:    ret half [[SUB_I]]
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[eta]], [[eta]]
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N:%.*]], %I
+// CHECK_NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0:%.*]], [[TMP0:%.*]]
+// CHECK_NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1:%.*]]
+// CHECK_NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %mul.i, [[SUB1_I:%.*]]
+// CHECK_NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I:%.*]]
+// CHECK_NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I:%.*]], 0xH0000
+// CHECK_NEXT:    br i1 [[CMP_I:%.*]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
+// CHECK_NEXT:  
+// CHECK_NEXT:    if.else.i:                                        ; preds = %entry
+// CHECK_NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[eta]], %I
+// CHECK_NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N:%.*]], %I
+// CHECK_NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I:%.*]], [[eta]]
+// CHECK_NEXT:    %2 = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I:%.*]])
+// CHECK_NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half %2, [[MUL8_I:%.*]]
+// CHECK_NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I:%.*]], [[N:%.*]]
+// CHECK_NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I:%.*]], [[MUL9_I:%.*]]
+// CHECK_NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
+// CHECK_NEXT:    
+// CHECK_NEXT:    _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
+// CHECK_NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I:%.*]], %if.else.i ], [ 0xH0000, %entry ]
+// CHECK_NEXT:    ret half [[RETVAL_0_I:%.*]]
+  
 //
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
 // SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 28037d7ec5616..8b1864e7c6710 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -865,8 +865,27 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID,
 
   SmallVector<Type *> ArgTys;
   ArgTys.reserve(Args.size());
-  for (auto &I : Args)
+  int i =0;
+  Type * Ity;
+  Type * Nty;
+  Type * etaty;
+
+  for (auto &I : Args) {
+    if(i ==0)
+      Ity = I->getType();
+    if(i ==1)
+      Nty = I->getType();
+    if(i ==2)
+      etaty = I->getType();
     ArgTys.push_back(I->getType());
+    i++;
+  }
+  //assert(Ity == RetTy);
+  //assert(Nty == RetTy);
+  assert(Nty == Ity);
+
+
+
   FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false);
   SmallVector<Type *> OverloadTys;
   Intrinsic::MatchIntrinsicTypesResult Res =

>From 2f7d9ce4daaa8700b38066ea7881e281061386fc Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Tue, 22 Apr 2025 16:47:25 -0700
Subject: [PATCH 11/23] Update CodeGen checks

---
 clang/test/CodeGenHLSL/builtins/reflect.hlsl | 28 ++++----------------
 1 file changed, 5 insertions(+), 23 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/reflect.hlsl b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
index dc5011eb4b1a7..351ffac1c1283 100644
--- a/clang/test/CodeGenHLSL/builtins/reflect.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
@@ -8,29 +8,11 @@
 // CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
 // CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[eta]], [[eta]]
-// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N:%.*]], %I
-// CHECK_NEXT:    [[TMP1:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0:%.*]], [[TMP0:%.*]]
-// CHECK_NEXT:    [[SUB1_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[TMP1:%.*]]
-// CHECK_NEXT:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %mul.i, [[SUB1_I:%.*]]
-// CHECK_NEXT:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I:%.*]]
-// CHECK_NEXT:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half [[SUB5_I:%.*]], 0xH0000
-// CHECK_NEXT:    br i1 [[CMP_I:%.*]], label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit, label %if.else.i
-// CHECK_NEXT:  
-// CHECK_NEXT:    if.else.i:                                        ; preds = %entry
-// CHECK_NEXT:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[eta]], %I
-// CHECK_NEXT:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N:%.*]], %I
-// CHECK_NEXT:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I:%.*]], [[eta]]
-// CHECK_NEXT:    %2 = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half [[SUB5_I:%.*]])
-// CHECK_NEXT:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half %2, [[MUL8_I:%.*]]
-// CHECK_NEXT:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I:%.*]], [[N:%.*]]
-// CHECK_NEXT:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I:%.*]], [[MUL9_I:%.*]]
-// CHECK_NEXT:    br label %_ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit
-// CHECK_NEXT:    
-// CHECK_NEXT:    _ZN4hlsl8__detail12refract_implIDhEET_S2_S2_S2_.exit: ; preds = %entry, %if.else.i
-// CHECK_NEXT:    [[RETVAL_0_I:%.*]] = phi nsz half [ [[SUB10_I:%.*]], %if.else.i ], [ 0xH0000, %entry ]
-// CHECK_NEXT:    ret half [[RETVAL_0_I:%.*]]
-  
+// CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[I]], 0xH4000
+// CHECK-NEXT:    [[TMP0:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[N]], [[N]]
+// CHECK-NEXT:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[TMP0]], [[MUL_I]]
+// CHECK-NEXT:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[I]], [[MUL2_I]]
+// CHECK-NEXT:    ret half [[SUB_I]]
 //
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
 // SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {

>From 76950bd4d534b8048eecdf3f55f7c5b3c01a0cc3 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Thu, 1 May 2025 17:26:32 -0700
Subject: [PATCH 12/23] Addressing partial review comments

---
 clang/lib/Sema/SemaSPIRV.cpp               | 3 +--
 clang/test/CodeGenSPIRV/Builtins/refract.c | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 1e086c0738863..f564eabaf3fe4 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -257,8 +257,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     ExprResult C = TheCall->getArg(2);
     QualType ArgTyC = C.get()->getType();
     if (!ArgTyC->hasFloatingRepresentation()) {
-      SemaRef.Diag(C.get()->getBeginLoc(),
-                   diag::err_builtin_invalid_arg_type)
+      SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
           << ArgTyC;
       return true;
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index 82f620fa293de..a867b20ae123a 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -1,6 +1,6 @@
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 
-// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
 
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float3 __attribute__((ext_vector_type(3)));

>From 78fcafd881109f3bf2129cef103f03aff5d0680f Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 14 May 2025 12:41:21 -0700
Subject: [PATCH 13/23] modularize SemaSPIRV.cpp

---
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h |  1 +
 clang/lib/Sema/SemaSPIRV.cpp                  | 30 ++++---------------
 clang/test/CodeGenSPIRV/Builtins/refract.c    |  2 +-
 3 files changed, 7 insertions(+), 26 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 864bc81fd3e63..5d7cd54a41422 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -88,6 +88,7 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
 #endif
 }
 
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index f564eabaf3fe4..323e12c48d84f 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -230,40 +230,20 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    ExprResult A = TheCall->getArg(0);
-    QualType ArgTyA = A.get()->getType();
-    auto *VTyA = ArgTyA->getAs<VectorType>();
-    if (VTyA == nullptr) {
-      SemaRef.Diag(A.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyA
-          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-
-    ExprResult B = TheCall->getArg(1);
-    QualType ArgTyB = B.get()->getType();
-    auto *VTyB = ArgTyB->getAs<VectorType>();
-    if (VTyB == nullptr) {
-      SemaRef.Diag(B.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTyB
-          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
-          << 0 << 0;
+    // Use the helper function to check the first two arguments
+    if (CheckVectorArgs(TheCall, 2))
       return true;
-    }
 
     ExprResult C = TheCall->getArg(2);
     QualType ArgTyC = C.get()->getType();
-    if (!ArgTyC->hasFloatingRepresentation()) {
+    if (!ArgTyC->isFloatingType()) {
       SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
+          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1
           << ArgTyC;
       return true;
     }
 
-    QualType RetTy = ArgTyA;
+    QualType RetTy = TheCall->getArg(0)->getType();
     TheCall->setType(RetTy);
     assert(RetTy == ArgTyA);
     //assert(ArgTyB == ArgTyA);
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index a867b20ae123a..82f620fa293de 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -1,6 +1,6 @@
 // NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 
-// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
 
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float3 __attribute__((ext_vector_type(3)));

>From 2a7720c608de808066443e280b5fdda5064fa298 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 14 May 2025 13:52:08 -0700
Subject: [PATCH 14/23]  clang-format

---
 clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 1 -
 clang/lib/Sema/SemaSPIRV.cpp                    | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 5d7cd54a41422..864bc81fd3e63 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -88,7 +88,6 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
 #endif
 }
 
-
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 323e12c48d84f..92fdbe75df423 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -238,8 +238,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     QualType ArgTyC = C.get()->getType();
     if (!ArgTyC->isFloatingType()) {
       SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1
-          << ArgTyC;
+          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
       return true;
     }
 

>From 199bda926f7ec40f394e6a4aba4d9f9907608fa3 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 25 Jun 2025 16:28:39 -0700
Subject: [PATCH 15/23] update SPV to CustomTypeChecking

---
 clang/include/clang/Basic/BuiltinsSPIRV.td    |  39 ---
 clang/include/clang/Basic/BuiltinsSPIRVVK.td  |   1 +
 clang/include/clang/Sema/Sema.h               |  24 ++
 clang/lib/Headers/hlsl/hlsl_detail.h          |   8 +
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h |  29 +-
 clang/lib/Sema/SemaChecking.cpp               | 105 +++++++
 clang/lib/Sema/SemaHLSL.cpp                   |  73 ++++-
 clang/lib/Sema/SemaSPIRV.cpp                  |  62 +---
 clang/test/CodeGenHLSL/builtins/reflect.hlsl  |   5 +-
 clang/test/CodeGenHLSL/builtins/refract.hlsl  | 295 ++++++++----------
 clang/test/CodeGenSPIRV/Builtins/refract.c    |  11 +-
 .../test/SemaSPIRV/BuiltIns/refract-errors.c  |   9 +-
 llvm/lib/IR/IRBuilder.cpp                     |  23 +-
 .../CodeGen/SPIRV/hlsl-intrinsics/refract.ll  |  15 +-
 14 files changed, 390 insertions(+), 309 deletions(-)
 delete mode 100644 clang/include/clang/Basic/BuiltinsSPIRV.td

diff --git a/clang/include/clang/Basic/BuiltinsSPIRV.td b/clang/include/clang/Basic/BuiltinsSPIRV.td
deleted file mode 100644
index c0f652b4f24e4..0000000000000
--- a/clang/include/clang/Basic/BuiltinsSPIRV.td
+++ /dev/null
@@ -1,39 +0,0 @@
-//===--- BuiltinsSPIRV.td - SPIRV Builtin function database ---------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-include "clang/Basic/BuiltinsBase.td"
-
-def SPIRVDistance : Builtin {
-  let Spellings = ["__builtin_spirv_distance"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
-def SPIRVLength : Builtin {
-  let Spellings = ["__builtin_spirv_length"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
-def SPIRVReflect : Builtin {
-  let Spellings = ["__builtin_spirv_reflect"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
-def SPIRVRefract : Builtin {
-  let Spellings = ["__builtin_spirv_refract"];
-  let Attributes = [NoThrow, Const];
-  let Prototype = "void(...)";
-}
-
-def SPIRVSmoothStep : Builtin {
-  let Spellings = ["__builtin_spirv_smoothstep"];
-  let Attributes = [NoThrow, Const, CustomTypeChecking];
-  let Prototype = "void(...)";
-}
diff --git a/clang/include/clang/Basic/BuiltinsSPIRVVK.td b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
index 61cc0343c415e..5dc3c7588cd2a 100644
--- a/clang/include/clang/Basic/BuiltinsSPIRVVK.td
+++ b/clang/include/clang/Basic/BuiltinsSPIRVVK.td
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
 
 def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
 def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
+def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 3fe26f950ad51..105ab804fffd0 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,6 +2791,30 @@ class Sema final : public SemaBase {
 
   void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
 
+  /// Diagnose (and return true for) the first of the leading \p NumArgsToCheck
+  /// arguments of \p TheCall that is not a vector; the overload checks all.
+  bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
+  bool CheckVectorArgs(CallExpr *TheCall);
+
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::ArrayRef<
+          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+          Checks);
+  bool CheckAllArgTypesAreCorrect(
+      Sema *S, CallExpr *TheCall,
+      llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
+
+  static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                             int ArgOrdinal,
+                                             clang::QualType PassedType);
+  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                    int ArgOrdinal,
+                                                    clang::QualType PassedType);
+
+  static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                   int ArgOrdinal,
+                                                   clang::QualType PassedType);
   /// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
   /// TheCall is a constant expression.
   bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 80c4900121dfb..96e101a1e3aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
 
+template <typename T> struct is_vector {
+  static const bool value = false;
+};
+
+template <typename T, int N> struct is_vector<vector<T, N>> {
+  static const bool value = true;
+};
+
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
     vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 864bc81fd3e63..f6acb1cea2594 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -72,22 +72,41 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 }
 
 template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
-  T K = 1 - Eta * Eta * (1 - (N * I * N * I));
-  T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N);
+  T Mul = N * I;
+  T K = 1 - Eta * Eta * (1 - (Mul * Mul));
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
   return select<T>(K < 0, static_cast<T>(0), Result);
 }
 
+template <typename T, typename U>
+constexpr T refract_vec_impl(T I, T N, U Eta) {
+#if (__has_builtin(__builtin_spirv_refract))
+  if (is_vector<T>::value)
+    return __builtin_spirv_refract(I, N, Eta);
+#endif
+  // Generic expansion: the only path on targets without the SPIR-V builtin,
+  // and the fallthrough that guarantees a value is returned on every path.
+  T Mul = dot(N, I);
+  T K = 1 - Eta * Eta * (1 - Mul * Mul);
+  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
+  return select<T>(K < 0, static_cast<T>(0), Result);
+}
+
+/*
 template <typename T, int L>
 constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
-#if (__has_builtin(__builtin_spirv_refract))
+#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
   return __builtin_spirv_refract(I, N, Eta);
 #else
-  vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
-  vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
+  T Mul = dot(N, I);
+  vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
+  vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
   return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
 #endif
 }
 
+*/
+
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index dd5b710d7e1d4..98bca59f14ecd 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16151,3 +16151,108 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
     }
   }
 }
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
+  // Diagnose the first of the leading NumArgsToCheck arguments that is not a
+  // vector. NumArgsToCheck must not exceed TheCall->getNumArgs().
+  for (unsigned I = 0; I < NumArgsToCheck; ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    QualType ArgTy = Arg->getType();
+    if (ArgTy->isVectorType())
+      continue;
+    // Sema is its own SemaRef, so diagnose through Diag/Context directly.
+    Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
+        << ArgTy << Context.getVectorType(ArgTy, 2, VectorKind::Generic)
+        << 1 << 0 << 0;
+    return true;
+  }
+  return false;
+}
+
+bool Sema::CheckVectorArgs(CallExpr *TheCall) {
+  return CheckVectorArgs(TheCall, TheCall->getNumArgs());
+}
+
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments.
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument.
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // The call has the wrong arity for this set of checks. Report it with
+    // the argument-count diagnostic; err_builtin_invalid_arg_type describes
+    // one argument's type and cannot express a count mismatch.
+    return S->checkArgCount(TheCall, Checks.size());
+  }
+}
+
+bool Sema::CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
+}
+
+bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  // Only vectors pass, and only when their elements are half or 32-bit float;
+  // the explicit grouping also avoids clang's -Wlogical-op-parentheses.
+  const auto *VecTy = PassedType->getAs<VectorType>();
+  bool IsHalfOrFloatVector =
+      VecTy && (VecTy->getElementType()->isHalfType() ||
+                VecTy->getElementType()->isFloat32Type());
+  if (!IsHalfOrFloatVector)
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+bool Sema::CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                int ArgOrdinal,
+                                                clang::QualType PassedType) {
+  // Only a scalar half or 32-bit float passes. Vectors are rejected whatever
+  // their element type, so the element type is never inspected here.
+  bool IsHalfOrFloatScalar =
+      !PassedType->isVectorType() &&
+      (PassedType->isHalfType() || PassedType->isFloat32Type());
+  if (IsHalfOrFloatScalar)
+    return false;
+  // Vectors are exactly what this check rejects, so describe the expected
+  // operand as a scalar rather than "scalar or vector of".
+  return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+         << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
+         << /* half or float */ 2 << PassedType;
+}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index bad357b50929b..991d330edfb6f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-static bool CheckAllArgTypesAreCorrect(
+static bool CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
-                            clang::QualType PassedType)>
-        Check) {
-  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
-    Expr *Arg = TheCall->getArg(I);
-    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-      return true;
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  if (Checks.size() == 1) {
+    // Apply the single check to all arguments.
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else if (Checks.size() == NumArgs) {
+    // Apply each check to the corresponding argument.
+    for (unsigned I = 0; I < NumArgs; ++I) {
+      Expr *Arg = TheCall->getArg(I);
+      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+        return true;
+    }
+    return false;
+  } else {
+    // The call has the wrong arity for this set of checks. Report it with
+    // the argument-count diagnostic; err_builtin_invalid_arg_type describes
+    // one argument's type and cannot express a count mismatch.
+    return S->checkArgCount(TheCall, Checks.size());
   }
-  return false;
+}
+
+static bool CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
 }
 
 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
   return false;
 }
 
+static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                  int ArgOrdinal,
+                                                  clang::QualType PassedType) {
+  // Only vectors pass, and only when their elements are half or 32-bit float;
+  // the explicit grouping also avoids clang's -Wlogical-op-parentheses.
+  const auto *VecTy = PassedType->getAs<VectorType>();
+  bool IsHalfOrFloatVector =
+      VecTy && (VecTy->getElementType()->isHalfType() ||
+                VecTy->getElementType()->isFloat32Type());
+  if (!IsHalfOrFloatVector)
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* vector of */ 4 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  // Only a scalar half or 32-bit float passes; vectors are always rejected.
+  bool IsHalfOrFloatScalar =
+      !PassedType->isVectorType() &&
+      (PassedType->isHalfType() || PassedType->isFloat32Type());
+  if (IsHalfOrFloatScalar)
+    return false;
+  // Vectors are what this check rejects, so describe the expected operand as
+  // a scalar rather than "scalar or vector of".
+  return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+         << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
+         << /* half or float */ 2 << PassedType;
+}
+
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                   unsigned ArgIndex) {
   auto *Arg = TheCall->getArg(ArgIndex);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 92fdbe75df423..86ec569499a5c 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -29,32 +29,6 @@ namespace clang {
 
 SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
 
-/// Checks if the first `NumArgsToCheck` arguments of a function call are of
-/// vector type. If any of the arguments is not a vector type, it emits a
-/// diagnostic error and returns `true`. Otherwise, it returns `false`.
-///
-/// \param TheCall The function call expression to check.
-/// \param NumArgsToCheck The number of arguments to check for vector type.
-/// \return `true` if any of the arguments is not a vector type, `false`
-/// otherwise.
-
-bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
-  for (unsigned i = 0; i < NumArgsToCheck; ++i) {
-    ExprResult Arg = TheCall->getArg(i);
-    QualType ArgTy = Arg.get()->getType();
-    auto *VTy = ArgTy->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(Arg.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTy
-          << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-  }
-  return false;
-}
-
 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   assert(TheCall->getNumArgs() > 1);
   QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -71,7 +45,6 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   }
   return false;
 }
-bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
 
 static std::optional<int>
 processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
@@ -185,7 +158,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
       return true;
 
     // Use the helper function to check both arguments
-    if (CheckVectorArgs(TheCall, 2))
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
 
     QualType RetTy =
@@ -198,7 +171,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
       return true;
 
     // Use the helper function to check the argument
-    if (CheckVectorArgs(TheCall, 1))
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
 
     QualType RetTy =
@@ -210,8 +183,12 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
 
-    // Use the helper function to check the first two arguments
-    if (CheckVectorArgs(TheCall, 2))
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
+        ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfVectorsRepresentation,
+                       Sema::CheckFloatOrHalfScalarRepresentation};
+    if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                           llvm::ArrayRef(ChecksArr)))
       return true;
 
     ExprResult C = TheCall->getArg(2);
@@ -226,34 +203,13 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     TheCall->setType(RetTy);
     break;
   }
-  case SPIRV::BI__builtin_spirv_refract: {
-    if (SemaRef.checkArgCount(TheCall, 3))
-      return true;
-
-    // Use the helper function to check the first two arguments
-    if (CheckVectorArgs(TheCall, 2))
-      return true;
 
-    ExprResult C = TheCall->getArg(2);
-    QualType ArgTyC = C.get()->getType();
-    if (!ArgTyC->isFloatingType()) {
-      SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
-      return true;
-    }
-
-    QualType RetTy = TheCall->getArg(0)->getType();
-    TheCall->setType(RetTy);
-    assert(RetTy == ArgTyA);
-    //assert(ArgTyB == ArgTyA);
-    break;
-  }
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
     // Use the helper function to check both arguments
-    if (CheckVectorArgs(TheCall, 2))
+    if (SemaRef.CheckVectorArgs(TheCall))
       return true;
 
     QualType RetTy = TheCall->getArg(0)->getType();
diff --git a/clang/test/CodeGenHLSL/builtins/reflect.hlsl b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
index 351ffac1c1283..2d79ef45dabdc 100644
--- a/clang/test/CodeGenHLSL/builtins/reflect.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
@@ -1,3 +1,4 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
 // RUN: %clang_cc1 -finclude-default-header -triple \
 // RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
 // RUN:   -emit-llvm -O1 -o - | FileCheck %s
@@ -5,7 +6,7 @@
 // RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
 // RUN:   -emit-llvm -O1 -o - | FileCheck %s --check-prefix=SPVCHECK
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) half @_Z17test_reflect_halfDhDh(
 // CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
 // CHECK-NEXT:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[I]], 0xH4000
@@ -173,4 +174,4 @@ float3 test_reflect_float3(float3 I, float3 N) {
 //
 float4 test_reflect_float4(float4 I, float4 N) {
     return reflect(I, N);
-}
+}
\ No newline at end of file
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index 22bda78c82b4a..5fab2f988e7c5 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -5,44 +5,40 @@
 // RUN:   spirv-unknown-vulkan-compute %s -fnative-half-type \
 // RUN:   -emit-llvm -o - | FileCheck %s --check-prefix=SPVCHECK
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
 // CHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK:  [[ENTRY:.*:]]
 // CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // CHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// CHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], %{{.*}}
-// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL2_I]], %{{.*}}
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
-// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// CHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL2_I]]
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
+// CHECK:    [[MUL5_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // CHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// CHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], %{{.*}}
-// CHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL8_I]], [[TMP2]]
-// CHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL6_I]], %{{.*}}
+// CHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB8_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half %{{.*}}, [[MUL7_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half %{{.*}}, 0xH0000
 // CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], half 0xH0000, half %{{.*}}
 // CHECK:    ret half [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
 // SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] 
 // SPVCHECK:  [[ENTRY:.*:]]
 // SPVCHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // SPVCHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], %{{.*}}
-// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL2_I]], %{{.*}}
-// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
-// SPVCHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// SPVCHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL4_I]]
+// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL2_I]]
+// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], [[SUB_I]]
+// SPVCHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
+// SPVCHECK:    [[MUL5_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // SPVCHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL7_I]], %{{.*}}
-// SPVCHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
-// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL8_I]], [[TMP2]]
-// SPVCHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
-// SPVCHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[MUL9_I]]
+// SPVCHECK:    [[TMP18:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
+// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[TMP18]]
+// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
+// SPVCHECK:    [[SUB8_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL5_I]], [[MUL7_I]]
 // SPVCHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half %{{.*}}, 0xH0000
 // SPVCHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], half 0xH0000, half %{{.*}}
 // SPVCHECK:    ret half [[HLSL_SELECT_I]]
@@ -51,140 +47,131 @@ half test_refract_half(half I, half N, half ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
 // CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK:  [[ENTRY:.*:]]
-// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
-// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
-// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
-// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
-// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, %{{.*}}
-// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}})
-// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, [[HLSL_DOT8_I]]
-// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, [[TMP0]]
-// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> splat (half 0xH3C00), [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> splat (half 0xH3C00), [[MUL4_I]]
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[MUL11_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x half> [[MUL11_I]], [[TMP17]]
+// CHECK:    [[MUL12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x half> [[MUL8_I]], [[MUL12_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <2 x half> %{{.*}}, zeroinitializer
-// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <2 x half> zeroinitializer, <2 x half> %{{.*}}
+// CHECK:    [[CAST:%.*]] = extractelement <2 x i1> [[CMP_I]], i32 0
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CAST]], <2 x half> zeroinitializer, <2 x half> %{{.*}}
 // CHECK:    ret <2 x half> [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
 // SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> %{{.*}}, <2 x half> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <2 x half> [[SPV_REFRACT_I]]
 //
 half2 test_refract_half2(half2 I, half2 N, half ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
 // CHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK:  [[ENTRY:.*:]]
-// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
-// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
-// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
-// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
-// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, %{{.*}}
-// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}})
-// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, [[HLSL_DOT8_I]]
-// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, [[TMP0]]
-// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> splat (half 0xH3C00), [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> splat (half 0xH3C00), [[MUL4_I]]
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[MUL11_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn <3 x half> @llvm.sqrt.v3f16(<3 x half> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x half> [[MUL11_I]], [[TMP17]]
+// CHECK:    [[MUL12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x half> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x half> [[MUL8_I]], [[MUL12_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <3 x half> %{{.*}}, zeroinitializer
-// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <3 x half> zeroinitializer, <3 x half> %{{.*}}
+// CHECK:    [[CAST:%.*]] = extractelement <3 x i1> [[CMP_I]], i32 0
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CAST]], <3 x half> zeroinitializer, <3 x half> %{{.*}}
 // CHECK:    ret <3 x half> [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
 // SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f64(<3 x half> %{{.*}}, <3 x half> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <3 x half> [[SPV_REFRACT_I]]
 //
 half3 test_refract_half3(half3 I, half3 N, half ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
 // CHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK:  [[ENTRY:.*:]]
-// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
 // CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
-// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
-// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
-// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_3_I]]
-// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, %{{.*}}
-// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}})
-// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, [[HLSL_DOT8_I]]
-// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.sqrt.v4f16(<4 x half> %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, [[TMP0]]
-// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> splat (half 0xH3C00), [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> splat (half 0xH3C00), [[MUL4_I]]
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[MUL11_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> %{{.*}}, %{{.*}}
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn <4 x half> @llvm.sqrt.v4f16(<4 x half> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x half> [[MUL11_I]], [[TMP17]]
+// CHECK:    [[MUL12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x half> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x half> [[MUL8_I]], [[MUL12_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <4 x half> %{{.*}}, zeroinitializer
-// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <4 x half> zeroinitializer, <4 x half> %{{.*}}
-// CHECK:    ret <4 x half> [[HLSL_SELECT_I]] 
+// CHECK:    [[CAST:%.*]] = extractelement <4 x i1> [[CMP_I]], i32 0
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CAST]], <4 x half> zeroinitializer, <4 x half> %{{.*}}
+// CHECK:    ret <4 x half> [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
 // SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half %{{.*}} to double
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f64(<4 x half> %{{.*}}, <4 x half> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <4 x half> [[SPV_REFRACT_I]]
 //
 half4 test_refract_half4(half4 I, half4 N, half ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
 // CHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK:  [[ENTRY:.*:]]
 // CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // CHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// CHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], %{{.*}}
-// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL2_I]], %{{.*}}
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
-// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL4_I]]
+// CHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL2_I]]
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], [[SUB_I]]
+// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
+// CHECK:    [[MUL5_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // CHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// CHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL7_I]], %{{.*}}
-// CHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL8_I]], [[TMP2]]
-// CHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL6_I]], [[MUL9_I]]
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL6_I]], %{{.*}}
+// CHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB8_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float %{{.*}}, [[MUL7_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float %{{.*}}, 0.000000e+00
 // CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], float 0.000000e+00, float %{{.*}}
 // CHECK:    ret float [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
+//
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
 // SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
 // SPVCHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // SPVCHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], %{{.*}}
-// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL2_I]], %{{.*}}
-// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
-// SPVCHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// SPVCHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL4_I]]
+// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL2_I]]
+// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], [[SUB_I]]
+// SPVCHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
+// SPVCHECK:    [[MUL5_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // SPVCHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL7_I]], %{{.*}}
-// SPVCHECK:    [[TMP2:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
-// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL8_I]], [[TMP2]]
-// SPVCHECK:    [[MUL9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
-// SPVCHECK:    [[SUB10_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL6_I]], [[MUL9_I]]
+// SPVCHECK:    [[TMP18:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
+// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL6_I]], [[TMP18]]
+// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
+// SPVCHECK:    [[SUB8_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL5_I]], [[MUL7_I]]
 // SPVCHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float %{{.*}}, 0.000000e+00
 // SPVCHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], float 0.000000e+00, float %{{.*}}
 // SPVCHECK:    ret float [[HLSL_SELECT_I]]
@@ -193,96 +180,90 @@ float test_refract_float(float I, float N, float ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
 // CHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK:  [[ENTRY:.*:]]
-// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
-// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
-// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
-// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
-// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, %{{.*}}
-// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}})
-// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, [[HLSL_DOT8_I]]
-// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, [[TMP0]]
-// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> splat (float 1.000000e+00), [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> splat (float 1.000000e+00), [[MUL4_I]]
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[MUL11_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.sqrt.v2f32(<2 x float> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <2 x float> [[MUL11_I]], [[TMP17]]
+// CHECK:    [[MUL12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x float> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <2 x float> [[MUL8_I]], [[MUL12_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <2 x float> %{{.*}}, zeroinitializer
-// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <2 x float> zeroinitializer, <2 x float> %{{.*}}
+// CHECK:    [[CAST:%.*]] = extractelement <2 x i1> [[CMP_I]], i32 0
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CAST]], <2 x float> zeroinitializer, <2 x float> %{{.*}}
 // CHECK:    ret <2 x float> [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
 // SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float %{{.*}} to double
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f64(<2 x float> %{{.*}}, <2 x float> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <2 x float> [[SPV_REFRACT_I]]
 //
 float2 test_refract_float2(float2 I, float2 N, float ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
 // CHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
-// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
-// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
-// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
-// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, %{{.*}}
-// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}})
-// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, [[HLSL_DOT8_I]]
-// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.sqrt.v3f32(<3 x float> %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, [[TMP0]]
-// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> splat (float 1.000000e+00), [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> splat (float 1.000000e+00), [[MUL4_I]]
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[MUL11_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.sqrt.v3f32(<3 x float> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <3 x float> [[MUL11_I]], [[TMP17]]
+// CHECK:    [[MUL12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <3 x float> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <3 x float> [[MUL8_I]], [[MUL12_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <3 x float> %{{.*}}, zeroinitializer
-// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <3 x float> zeroinitializer, <3 x float> %{{.*}}
+// CHECK:    [[CAST:%.*]] = extractelement <3 x i1> [[CMP_I]], i32 0
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CAST]], <3 x float> zeroinitializer, <3 x float> %{{.*}}
 // CHECK:    ret <3 x float> [[HLSL_SELECT_I]]
 //
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
 // SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float %{{.*}} to double
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f64(<3 x float> %{{.*}}, <3 x float> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <3 x float> [[SPV_REFRACT_I]]
 //
 float3 test_refract_float3(float3 I, float3 N, float ETA) {
     return refract(I, N, ETA);
 }
 
-// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f
+// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f
 // CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[I:%.*]], <4 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
 // CHECK:    [[HLSL_DOT_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
-// CHECK:    [[HLSL_DOT1_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
-// CHECK:    [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[HLSL_DOT_I]], [[HLSL_DOT1_I]]
-// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_2_I]]
-// CHECK:    [[MUL_3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL_I]], [[SUB_I]]
-// CHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL_3_I]]
-// CHECK:    [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, %{{.*}}
-// CHECK:    [[HLSL_DOT8_I:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.dx.fdot.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}})
-// CHECK:    [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, [[HLSL_DOT8_I]]
-// CHECK:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.sqrt.v4f32(<4 x float> %{{.*}})
-// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, [[TMP0]]
-// CHECK:    [[MUL_12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[ADD_I]], %{{.*}}
-// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[MUL_7_I]], [[MUL_12_I]]
+// CHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
+// CHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> splat (float 1.000000e+00), [[MUL3_I]]
+// CHECK:    [[MUL4_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, [[SUB_I]]
+// CHECK:    [[SUB5_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> splat (float 1.000000e+00), [[MUL4_I]]
+// CHECK:    [[MUL8_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[MUL11_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> %{{.*}}, %{{.*}}
+// CHECK:    [[TMP17:%.*]] = call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.sqrt.v4f32(<4 x float> %{{.*}})
+// CHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn <4 x float> [[MUL11_I]], [[TMP17]]
+// CHECK:    [[MUL12_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <4 x float> [[ADD_I]], %{{.*}}
+// CHECK:    [[SUB13_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn <4 x float> [[MUL8_I]], [[MUL12_I]]
 // CHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt  <4 x float> %{{.*}}, zeroinitializer
-// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 %{{.*}}, <4 x float> zeroinitializer, <4 x float> %{{.*}}
+// CHECK:    [[CAST:%.*]] = extractelement <4 x i1> [[CMP_I]], i32 0
+// CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CAST]], <4 x float> zeroinitializer, <4 x float> %{{.*}}
 // CHECK:    ret <4 x float> [[HLSL_SELECT_I]]
 
-// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f(
+// SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f(
 // SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) %{{.*}}, <4 x float> noundef nofpclass(nan inf) %{{.*}}, float noundef nofpclass(nan inf) %{{.*}}) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn float %{{.*}} to double
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> %{{.*}}, <4 x float> %{{.*}}, double [[CONV_I]])
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <4 x float> [[SPV_REFRACT_I]]
 //
 float4 test_refract_float4(float4 I, float4 N, float ETA) {
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index 82f620fa293de..f477f532ffb6f 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -1,5 +1,3 @@
-// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
-
 // RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
 
 typedef float float2 __attribute__((ext_vector_type(2)));
@@ -9,8 +7,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
 // CHECK-LABEL: define spir_func <2 x float> @test_refract_float2(
 // CHECK-SAME: <2 x float> noundef [[I:%.*]], <2 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[CONV:%.*]] = fpext float [[ETA]] to double
-// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32.f64(<2 x float> [[I]], <2 x float> [[N]], double [[CONV]])
+// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> [[I]], <2 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <2 x float> [[SPV_REFRACT]]
 //
 float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
@@ -18,8 +15,7 @@ float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spi
 // CHECK-LABEL: define spir_func <3 x float> @test_refract_float3(
 // CHECK-SAME: <3 x float> noundef [[I:%.*]], <3 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[CONV:%.*]] = fpext float [[ETA]] to double
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32.f64(<3 x float> [[I]], <3 x float> [[N]], double [[CONV]])
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> [[I]], <3 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <3 x float> [[SPV_REFRACT]]
 //
 float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
@@ -27,8 +23,7 @@ float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spi
 // CHECK-LABEL: define spir_func <4 x float> @test_refract_float4(
 // CHECK-SAME: <4 x float> noundef [[I:%.*]], <4 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[CONV:%.*]] = fpext float [[ETA]] to double
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> [[I]], <4 x float> [[N]], double [[CONV]])
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> [[I]], <4 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <4 x float> [[SPV_REFRACT]]
 //
 float4 test_refract_float4(float4 I, float4 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
diff --git a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
index 9f6589bb05b6a..775f82a858dc2 100644
--- a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
+++ b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
@@ -2,11 +2,6 @@
 
 typedef float float2 __attribute__((ext_vector_type(2)));
 
-float2 test_no_second_arg(float2 p0) {
-  return __builtin_spirv_refract(p0);
-  // expected-error at -1 {{too few arguments to function call, expected 3, have 1}}
-}
-
 float2 test_no_third_arg(float2 p0) {
   return __builtin_spirv_refract(p0, p0);
   // expected-error at -1 {{too few arguments to function call, expected 3, have 2}}
@@ -19,10 +14,10 @@ float2 test_too_many_arg(float2 p0, float p1) {
 
 float test_double_scalar_inputs(double p0, double p1, double p2) {
   return __builtin_spirv_refract(p0, p1, p2);
-  //  expected-error at -1 {{passing 'double' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(double)))) double' (vector of 2 'double' values)}}
+  //  expected-error at -1 {{1st argument must be a scalar or vector of 16 or 32 bit floating-point types (was 'double')}}
 }
 
 float test_int_scalar_inputs(int p0, int p1, int p2) {
   return __builtin_spirv_refract(p0, p1, p2);
-  //  expected-error at -1 {{passing 'int' to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(int)))) int' (vector of 2 'int' values)}}
+  //  expected-error at -1 {{1st argument must be a scalar or vector of 16 or 32 bit floating-point types (was 'int')}}
 }
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 8b1864e7c6710..f30b606323483 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -865,27 +865,8 @@ CallInst *IRBuilderBase::CreateIntrinsic(Type *RetTy, Intrinsic::ID ID,
 
   SmallVector<Type *> ArgTys;
   ArgTys.reserve(Args.size());
-  int i =0;
-  Type * Ity;
-  Type * Nty;
-  Type * etaty;
-
-  for (auto &I : Args) {
-    if(i ==0)
-      Ity = I->getType();
-    if(i ==1)
-      Nty = I->getType();
-    if(i ==2)
-      etaty = I->getType();
+  for (auto &I : Args)
     ArgTys.push_back(I->getType());
-    i++;
-  }
-  //assert(Ity == RetTy);
-  //assert(Nty == RetTy);
-  assert(Nty == Ity);
-
-
-
   FunctionType *FTy = FunctionType::get(RetTy, ArgTys, false);
   SmallVector<Type *> OverloadTys;
   Intrinsic::MatchIntrinsicTypesResult Res =
@@ -1281,4 +1262,4 @@ IRBuilderDefaultInserter::~IRBuilderDefaultInserter() = default;
 IRBuilderCallbackInserter::~IRBuilderCallbackInserter() = default;
 IRBuilderFolder::~IRBuilderFolder() = default;
 void ConstantFolder::anchor() {}
-void NoFolder::anchor() {}
+void NoFolder::anchor() {}
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
index 780f7bfe60db1..c7a973c311ef8 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -7,7 +7,6 @@
 ; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
 ; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
 ; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
-; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
 ; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
 
 define noundef  <4 x half> @refract_half(<4 x half> noundef  %I, <4 x half> noundef  %N, half noundef  %ETA) {
@@ -15,11 +14,9 @@ entry:
   ; CHECK: %[[#]] = OpFunction %[[#vec4_float_16]] None %[[#]]
   ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
   ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
-  ; CHECK: %[[#arg2_float_16:]] = OpFunctionParameter %[[#float_16:]]
-  ; CHECK: %[[#arg2:]] = OpFConvert %[[#float_64:]] %[[#arg2_float_16:]]
+  ; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#float_16]]
   ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_16]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
-  %conv.i = fpext reassoc nnan ninf nsz arcp afn half %ETA to double
-  %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f64(<4 x half> %I, <4 x half> %N, double %conv.i)
+  %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half> %I, <4 x half> %N, half %ETA)
   ret <4 x half> %spv.refract.i
 }
 
@@ -29,9 +26,11 @@ entry:
   ; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
   ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
   ; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
-  ; CHECK: %[[#arg2_float_32:]] = OpFunctionParameter %[[#float_32:]]
-  ; CHECK: %[[#arg2:]] = OpFConvert %[[#float_64:]] %[[#arg2_float_32:]]
+  ; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#float_32]]
   ; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] Refract %[[#arg0]] %[[#arg1]] %[[#arg2]]
-  %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f64(<4 x float> %I, <4 x float> %N, double %conv.i)
+  %spv.refract.i = tail call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> %I, <4 x float> %N, float %ETA)
   ret <4 x float> %spv.refract.i
 }
+
+declare <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half>, <4 x half>, half)
+declare <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float>, <4 x float>, float)

>From 86d2f84347609059e777a7557df3d67ed7500592 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Thu, 3 Jul 2025 18:03:56 -0700
Subject: [PATCH 16/23] Remove blank line

---
 clang/lib/Sema/SemaSPIRV.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 86ec569499a5c..1b4093065a63a 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -203,7 +203,6 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     TheCall->setType(RetTy);
     break;
   }
-
   case SPIRV::BI__builtin_spirv_reflect: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;

>From 17d1fbbb2e0d395cafe797ebc375dda381a3d4ef Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Mon, 7 Jul 2025 18:16:47 -0700
Subject: [PATCH 17/23] revert arg checks for other intrinsics

---
 clang/include/clang/Sema/Sema.h               | 15 ++--
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 15 ----
 clang/lib/Sema/SemaChecking.cpp               | 45 ++--------
 clang/lib/Sema/SemaHLSL.cpp                   |  8 +-
 clang/lib/Sema/SemaSPIRV.cpp                  | 84 ++++++++++++++-----
 clang/test/CodeGenHLSL/builtins/reflect.hlsl  |  2 +-
 clang/test/CodeGenHLSL/intrinsic.ll           | 19 +++++
 llvm/lib/IR/IRBuilder.cpp                     |  2 +-
 8 files changed, 103 insertions(+), 87 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/intrinsic.ll

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 105ab804fffd0..255a5940ecb9c 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,11 +2791,6 @@ class Sema final : public SemaBase {
 
   void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
 
-  /// CheckVectorArgs - Check that the arguments of a vector function call
-  bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
-
-  bool CheckVectorArgs(CallExpr *TheCall);
-
   bool CheckAllArgTypesAreCorrect(
       Sema *S, CallExpr *TheCall,
       llvm::ArrayRef<
@@ -2806,15 +2801,15 @@ class Sema final : public SemaBase {
       llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
 
   static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
-                                            int ArgOrdinal,
-                                            clang::QualType PassedType);
-  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
                                              int ArgOrdinal,
                                              clang::QualType PassedType);
+  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
+                                                    int ArgOrdinal,
+                                                    clang::QualType PassedType);
 
   static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
-                                                int ArgOrdinal,
-                                                clang::QualType PassedType);
+                                                   int ArgOrdinal,
+                                                   clang::QualType PassedType);
   /// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
   /// TheCall is a constant expression.
   bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index f6acb1cea2594..39d2ae2b85076 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -92,21 +92,6 @@ constexpr T refract_vec_impl(T I, T N, U Eta) {
 #endif
 }
 
-/*
-template <typename T, int L>
-constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
-#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
-  return __builtin_spirv_refract(I, N, Eta);
-#else
-  T Mul = dot(N, I);
-  vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
-  vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
-  return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
-#endif
-}
-
-*/
-
 template <typename T> constexpr T fmod_impl(T X, T Y) {
 #if !defined(__DIRECTX__)
   return __builtin_elementwise_fmod(X, Y);
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 98bca59f14ecd..1660c63baa0bb 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16152,28 +16152,6 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
   }
 }
 
-bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
-  for (unsigned i = 0; i < NumArgsToCheck; ++i) {
-    ExprResult Arg = TheCall->getArg(i);
-    QualType ArgTy = Arg.get()->getType();
-    auto *VTy = ArgTy->getAs<VectorType>();
-    if (VTy == nullptr) {
-      SemaRef.Diag(Arg.get()->getBeginLoc(),
-                   diag::err_typecheck_convert_incompatible)
-          << ArgTy
-          << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
-          << 0 << 0;
-      return true;
-    }
-  }
-  return false;
-}
-
-bool Sema::CheckVectorArgs(CallExpr *TheCall) {
-  return CheckVectorArgs(TheCall, TheCall->getNumArgs());
-}
-
-
 bool Sema::CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall,
     llvm::ArrayRef<
@@ -16211,8 +16189,8 @@ bool Sema::CheckAllArgTypesAreCorrect(
 }
 
 bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
-                                           int ArgOrdinal,
-                                           clang::QualType PassedType) {
+                                          int ArgOrdinal,
+                                          clang::QualType PassedType) {
   clang::QualType BaseType =
       PassedType->isVectorType()
           ? PassedType->castAs<clang::VectorType>()->getElementType()
@@ -16225,8 +16203,8 @@ bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
 }
 
 bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
-                                                  int ArgOrdinal,
-                                                  clang::QualType PassedType) {
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
   const auto *VecTy = PassedType->getAs<VectorType>();
 
   clang::QualType BaseType =
@@ -16240,19 +16218,14 @@ bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
   return false;
 }
 
-bool Sema::CheckFloatOrHalfScalarRepresentation(
-    Sema *S, SourceLocation Loc,
-                                                 int ArgOrdinal,
-                                                 clang::QualType PassedType) {
+bool Sema::CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                int ArgOrdinal,
+                                                clang::QualType PassedType) {
   const auto *VecTy = PassedType->getAs<VectorType>();
 
-  clang::QualType BaseType =
-      PassedType->isVectorType()
-          ? PassedType->castAs<clang::VectorType>()->getElementType()
-          : PassedType;
-  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
+  if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type()))
     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
            << /* half or float */ 2 << PassedType;
   return false;
 }
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 991d330edfb6f..896d4da2d8a60 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2452,13 +2452,13 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
 }
 
 static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
-                                           int ArgOrdinal,
-                                           clang::QualType PassedType) {
+                                                  int ArgOrdinal,
+                                                  clang::QualType PassedType) {
   const auto *VecTy = PassedType->getAs<VectorType>();
 
-  clang::QualType BaseType = 
+  clang::QualType BaseType =
       PassedType->isVectorType()
-        ? PassedType->castAs<clang::VectorType>()->getElementType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
           : PassedType;
   if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 1b4093065a63a..f1704986f069a 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -157,25 +157,81 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
 
-    // Use the helper function to check both arguments
-    if (SemaRef.CheckVectorArgs(TheCall))
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
       return true;
+    }
 
-    QualType RetTy =
-        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    QualType RetTy = VTyA->getElementType();
     TheCall->setType(RetTy);
     break;
   }
   case SPIRV::BI__builtin_spirv_length: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTy = ArgTyA->getAs<VectorType>();
+    if (VTy == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+    QualType RetTy = VTy->getElementType();
+    TheCall->setType(RetTy);
+    break;
+  }
+  case SPIRV::BI__builtin_spirv_reflect: {
+    if (SemaRef.checkArgCount(TheCall, 2))
+      return true;
 
-    // Use the helper function to check the argument
-    if (SemaRef.CheckVectorArgs(TheCall))
+    ExprResult A = TheCall->getArg(0);
+    QualType ArgTyA = A.get()->getType();
+    auto *VTyA = ArgTyA->getAs<VectorType>();
+    if (VTyA == nullptr) {
+      SemaRef.Diag(A.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyA
+          << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
+          << 0 << 0;
+      return true;
+    }
+
+    ExprResult B = TheCall->getArg(1);
+    QualType ArgTyB = B.get()->getType();
+    auto *VTyB = ArgTyB->getAs<VectorType>();
+    if (VTyB == nullptr) {
+      SemaRef.Diag(B.get()->getBeginLoc(),
+                   diag::err_typecheck_convert_incompatible)
+          << ArgTyB
+          << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
+          << 0 << 0;
       return true;
+    }
 
-    QualType RetTy =
-        TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
+    QualType RetTy = ArgTyA;
     TheCall->setType(RetTy);
     break;
   }
@@ -203,18 +259,6 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     TheCall->setType(RetTy);
     break;
   }
-  case SPIRV::BI__builtin_spirv_reflect: {
-    if (SemaRef.checkArgCount(TheCall, 2))
-      return true;
-
-    // Use the helper function to check both arguments
-    if (SemaRef.CheckVectorArgs(TheCall))
-      return true;
-
-    QualType RetTy = TheCall->getArg(0)->getType();
-    TheCall->setType(RetTy);
-    break;
-  }
   case SPIRV::BI__builtin_spirv_smoothstep: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/reflect.hlsl b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
index 2d79ef45dabdc..65fefd801ffed 100644
--- a/clang/test/CodeGenHLSL/builtins/reflect.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/reflect.hlsl
@@ -174,4 +174,4 @@ float3 test_reflect_float3(float3 I, float3 N) {
 //
 float4 test_reflect_float4(float4 I, float4 N) {
     return reflect(I, N);
-}
\ No newline at end of file
+}
diff --git a/clang/test/CodeGenHLSL/intrinsic.ll b/clang/test/CodeGenHLSL/intrinsic.ll
new file mode 100644
index 0000000000000..104a6645ce7e3
--- /dev/null
+++ b/clang/test/CodeGenHLSL/intrinsic.ll
@@ -0,0 +1,19 @@
+; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
+define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(half noundef nofpclass(nan inf) %I, half noundef nofpclass(nan inf) %N, half noundef nofpclass(nan inf) %ETA) local_unnamed_addr #0 {
+entry:
+  %mul.i = fmul reassoc nnan ninf nsz arcp afn half %N, %I
+  %mul1.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %ETA
+  %mul2.i = fmul reassoc nnan ninf nsz arcp afn half %mul.i, %mul.i
+  %sub.i = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, %mul2.i
+  %mul3.i = fmul reassoc nnan ninf nsz arcp afn half %mul1.i, %sub.i
+  %sub4.i = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, %mul3.i
+  %mul5.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %I
+  %mul6.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %mul.i
+  %0 = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %sub4.i)
+  %add.i = fadd reassoc nnan ninf nsz arcp afn half %0, %mul6.i
+  %mul7.i = fmul reassoc nnan ninf nsz arcp afn half %add.i, %N
+  %sub8.i = fsub reassoc nnan ninf nsz arcp afn half %mul5.i, %mul7.i
+  %cmp.i = fcmp reassoc nnan ninf nsz arcp afn olt half %sub4.i, 0xH0000
+  %hlsl.select.i = select reassoc nnan ninf nsz arcp afn i1 %cmp.i, half 0xH0000, half %sub8.i
+  ret half %hlsl.select.i
+}
\ No newline at end of file
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index f30b606323483..28037d7ec5616 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -1262,4 +1262,4 @@ IRBuilderDefaultInserter::~IRBuilderDefaultInserter() = default;
 IRBuilderCallbackInserter::~IRBuilderCallbackInserter() = default;
 IRBuilderFolder::~IRBuilderFolder() = default;
 void ConstantFolder::anchor() {}
-void NoFolder::anchor() {}
\ No newline at end of file
+void NoFolder::anchor() {}

>From 7bdad404fa428d2952d58c245aa89fad691f4779 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Tue, 8 Jul 2025 15:10:01 -0700
Subject: [PATCH 18/23] Update Refract intrinsic

---
 clang/include/clang/Sema/Sema.h               | 21 +----
 clang/lib/Headers/hlsl/hlsl_detail.h          |  2 +-
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h | 16 +---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  4 +-
 clang/lib/Sema/SemaChecking.cpp               | 78 -------------------
 clang/lib/Sema/SemaHLSL.cpp                   | 75 +++---------------
 clang/lib/Sema/SemaSPIRV.cpp                  | 63 +++++++++++++--
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  6 +-
 8 files changed, 80 insertions(+), 185 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 255a5940ecb9c..6ad6ff460dcd7 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2791,25 +2791,6 @@ class Sema final : public SemaBase {
 
   void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
 
-  bool CheckAllArgTypesAreCorrect(
-      Sema *S, CallExpr *TheCall,
-      llvm::ArrayRef<
-          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
-          Checks);
-  bool CheckAllArgTypesAreCorrect(
-      Sema *S, CallExpr *TheCall,
-      llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
-
-  static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
-                                             int ArgOrdinal,
-                                             clang::QualType PassedType);
-  static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
-                                                    int ArgOrdinal,
-                                                    clang::QualType PassedType);
-
-  static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
-                                                   int ArgOrdinal,
-                                                   clang::QualType PassedType);
   /// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
   /// TheCall is a constant expression.
   bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);
@@ -15506,4 +15487,4 @@ void Sema::PragmaStack<Sema::AlignPackInfo>::Act(SourceLocation PragmaLocation,
 
 } // end namespace clang
 
-#endif
+#endif
\ No newline at end of file
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 96e101a1e3aa8..22250960b3289 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -55,7 +55,7 @@ template <typename T, int N> struct is_vector<vector<T, N>> {
 
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
-    vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
+    vector<__detail::enable_if_t<(N >= 1 && N <= 4), T>, N>;
 
 } // namespace __detail
 } // namespace hlsl
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 39d2ae2b85076..2ded062ea0d27 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -71,25 +71,15 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 #endif
 }
 
-template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
-  T Mul = N * I;
-  T K = 1 - Eta * Eta * (1 - (Mul * Mul));
-  T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
-  return select<T>(K < 0, static_cast<T>(0), Result);
-}
-
-template <typename T, typename U>
-constexpr T refract_vec_impl(T I, T N, U Eta) {
+template <typename T, typename U> constexpr T refract_impl(T I, T N, U Eta) {
 #if (__has_builtin(__builtin_spirv_refract))
-  if (is_vector<T>::value) {
+  if (is_vector<T>::value)
     return __builtin_spirv_refract(I, N, Eta);
-  }
-#else
+#endif
   T Mul = dot(N, I);
   T K = 1 - Eta * Eta * (1 - Mul * Mul);
   T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
   return select<T>(K < 0, static_cast<T>(0), Result);
-#endif
 }
 
 template <typename T> constexpr T fmod_impl(T X, T Y) {
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 8c262ffce25f1..499a05328ee4f 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -524,14 +524,14 @@ _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
     __detail::HLSL_FIXED_VECTOR<half, L> I,
     __detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
-  return __detail::refract_vec_impl(I, N, eta);
+  return __detail::refract_impl(I, N, eta);
 }
 
 template <int L>
 const inline __detail::HLSL_FIXED_VECTOR<float, L>
 refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
         __detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
-  return __detail::refract_vec_impl(I, N, eta);
+  return __detail::refract_impl(I, N, eta);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 1660c63baa0bb..dd5b710d7e1d4 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -16151,81 +16151,3 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
     }
   }
 }
-
-bool Sema::CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall,
-    llvm::ArrayRef<
-        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
-        Checks) {
-  unsigned NumArgs = TheCall->getNumArgs();
-  if (Checks.size() == 1) {
-    // Apply the single check to all arguments
-    for (unsigned I = 0; I < NumArgs; ++I) {
-      Expr *Arg = TheCall->getArg(I);
-      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-        return true;
-    }
-    return false;
-  } else if (Checks.size() == NumArgs) {
-    // Apply each check to the corresponding argument
-    for (unsigned I = 0; I < NumArgs; ++I) {
-      Expr *Arg = TheCall->getArg(I);
-      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-        return true;
-    }
-    return false;
-  } else {
-    // Mismatch: error or fallback
-    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-        << NumArgs << Checks.size();
-    return true;
-  }
-}
-
-bool Sema::CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
-  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
-}
-
-bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
-                                          int ArgOrdinal,
-                                          clang::QualType PassedType) {
-  clang::QualType BaseType =
-      PassedType->isVectorType()
-          ? PassedType->castAs<clang::VectorType>()->getElementType()
-          : PassedType;
-  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
-    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
-           << /* half or float */ 2 << PassedType;
-  return false;
-}
-
-bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
-                                                 int ArgOrdinal,
-                                                 clang::QualType PassedType) {
-  const auto *VecTy = PassedType->getAs<VectorType>();
-
-  clang::QualType BaseType =
-      PassedType->isVectorType()
-          ? PassedType->castAs<clang::VectorType>()->getElementType()
-          : PassedType;
-  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
-    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
-           << /* half or float */ 2 << PassedType;
-  return false;
-}
-
-bool Sema::CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
-                                                int ArgOrdinal,
-                                                clang::QualType PassedType) {
-  const auto *VecTy = PassedType->getAs<VectorType>();
-
-  if (VecTy || !PassedType->isHalfType() && !PassedType->isFloat32Type())
-    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
-           << /* half or float */ 2 << PassedType;
-  return false;
-}
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 896d4da2d8a60..505b11f722d7b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2401,40 +2401,17 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
   return false;
 }
 
-bool CheckAllArgTypesAreCorrect(
+static bool CheckAllArgTypesAreCorrect(
     Sema *S, CallExpr *TheCall,
-    llvm::ArrayRef<
-        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
-        Checks) {
-  unsigned NumArgs = TheCall->getNumArgs();
-  if (Checks.size() == 1) {
-    // Apply the single check to all arguments
-    for (unsigned I = 0; I < NumArgs; ++I) {
-      Expr *Arg = TheCall->getArg(I);
-      if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-        return true;
-    }
-    return false;
-  } else if (Checks.size() == NumArgs) {
-    // Apply each check to the corresponding argument
-    for (unsigned I = 0; I < NumArgs; ++I) {
-      Expr *Arg = TheCall->getArg(I);
-      if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
-        return true;
-    }
-    return false;
-  } else {
-    // Mismatch: error or fallback
-    S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-        << NumArgs << Checks.size();
-    return true;
+    llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
+                            clang::QualType PassedType)>
+        Check) {
+  for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+      return true;
   }
-}
-
-bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
-  return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
+  return false;
 }
 
 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2451,38 +2428,6 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
   return false;
 }
 
-static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
-                                                  int ArgOrdinal,
-                                                  clang::QualType PassedType) {
-  const auto *VecTy = PassedType->getAs<VectorType>();
-
-  clang::QualType BaseType =
-      PassedType->isVectorType()
-          ? PassedType->castAs<clang::VectorType>()->getElementType()
-          : PassedType;
-  if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
-    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgOrdinal << /* vector of */ 5 << /* no int */ 0
-           << /* half or float */ 2 << PassedType;
-  return false;
-}
-
-static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
-                                                 int ArgOrdinal,
-                                                 clang::QualType PassedType) {
-  const auto *VecTy = PassedType->getAs<VectorType>();
-
-  clang::QualType BaseType =
-      PassedType->isVectorType()
-          ? PassedType->castAs<clang::VectorType>()->getElementType()
-          : PassedType;
-  if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
-    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
-           << /* half or float */ 2 << PassedType;
-  return false;
-}
-
 static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
                                   unsigned ArgIndex) {
   auto *Arg = TheCall->getArg(ArgIndex);
@@ -4050,4 +3995,4 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
   }
   Init = C;
   return true;
-}
+}
\ No newline at end of file
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index f1704986f069a..4b99983e823a2 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -46,6 +46,59 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::ArrayRef<
+        llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
+        Checks) {
+  unsigned NumArgs = TheCall->getNumArgs();
+  assert(Checks.size() == NumArgs &&
+         "Wrong number of checks for Number of args.");
+  // Apply each check to the corresponding argument
+  for (unsigned I = 0; I < NumArgs; ++I) {
+    Expr *Arg = TheCall->getArg(I);
+    if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
+      return true;
+  }
+  return false;
+}
+
+static bool CheckAllArgTypesAreCorrect(
+    Sema *S, CallExpr *TheCall,
+    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
+  return CheckAllArgTypesAreCorrect(
+      S, TheCall,
+      SmallVector<
+          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>, 4>(
+          TheCall->getNumArgs(), Check));
+}
+
+static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
+                                           int ArgOrdinal,
+                                           clang::QualType PassedType) {
+  clang::QualType BaseType =
+      PassedType->isVectorType()
+          ? PassedType->castAs<clang::VectorType>()->getElementType()
+          : PassedType;
+  if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
+static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
+                                                 int ArgOrdinal,
+                                                 clang::QualType PassedType) {
+  const auto *VecTy = PassedType->getAs<VectorType>();
+
+  if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type()))
+    return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
+           << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
+           << /* half or float */ 2 << PassedType;
+  return false;
+}
+
 static std::optional<int>
 processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
   ExprResult Arg =
@@ -240,11 +293,11 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
       return true;
 
     llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
-        ChecksArr[] = {Sema::CheckFloatOrHalfVectorsRepresentation,
-                       Sema::CheckFloatOrHalfVectorsRepresentation,
-                       Sema::CheckFloatOrHalfScalarRepresentation};
-    if (SemaRef.CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
-                                           llvm::ArrayRef(ChecksArr)))
+        ChecksArr[] = {CheckFloatOrHalfRepresentation,
+                       CheckFloatOrHalfRepresentation,
+                       CheckFloatOrHalfScalarRepresentation};
+    if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
+                                   llvm::ArrayRef(ChecksArr)))
       return true;
 
     ExprResult C = TheCall->getArg(2);
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 2f20d11a49219..241fe59ee8620 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -75,7 +75,11 @@ let TargetPrefix = "spv" in {
     [IntrNoMem] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
-  def int_spv_refract : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_anyfloat_ty], [IntrNoMem]>;
+  def int_spv_refract
+      : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                              [llvm_anyfloat_ty, LLVMMatchType<0>,
+                               llvm_anyfloat_ty],
+                              [IntrNoMem]>;
   def int_spv_reflect : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

>From 39323c13b109c44be1b9edf89a5e67939e82b32d Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Tue, 8 Jul 2025 15:32:14 -0700
Subject: [PATCH 19/23] Update codegen-spirv test flag

---
 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
index c7a973c311ef8..58c86d9bc6aef 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -1,5 +1,5 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: llc -O0 -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
 ; Make sure SPIRV operation function calls for refract are lowered correctly.
 

>From 729fbf315c399116f5be995f6e2b6989d13f86b6 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Tue, 8 Jul 2025 15:44:43 -0700
Subject: [PATCH 20/23] remove typos

---
 clang/include/clang/Sema/Sema.h      |  2 +-
 clang/lib/Headers/hlsl/hlsl_detail.h |  2 +-
 clang/test/CodeGenHLSL/intrinsic.ll  | 19 -------------------
 3 files changed, 2 insertions(+), 21 deletions(-)
 delete mode 100644 clang/test/CodeGenHLSL/intrinsic.ll

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 6ad6ff460dcd7..3fe26f950ad51 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -15487,4 +15487,4 @@ void Sema::PragmaStack<Sema::AlignPackInfo>::Act(SourceLocation PragmaLocation,
 
 } // end namespace clang
 
-#endif
\ No newline at end of file
+#endif
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 22250960b3289..96e101a1e3aa8 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -55,7 +55,7 @@ template <typename T, int N> struct is_vector<vector<T, N>> {
 
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
-    vector<__detail::enable_if_t<(N >= 1 && N <= 4), T>, N>;
+    vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
 
 } // namespace __detail
 } // namespace hlsl
diff --git a/clang/test/CodeGenHLSL/intrinsic.ll b/clang/test/CodeGenHLSL/intrinsic.ll
deleted file mode 100644
index 104a6645ce7e3..0000000000000
--- a/clang/test/CodeGenHLSL/intrinsic.ll
+++ /dev/null
@@ -1,19 +0,0 @@
-; Function Attrs: alwaysinline mustprogress nofree norecurse nosync nounwind willreturn memory(none)
-define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(half noundef nofpclass(nan inf) %I, half noundef nofpclass(nan inf) %N, half noundef nofpclass(nan inf) %ETA) local_unnamed_addr #0 {
-entry:
-  %mul.i = fmul reassoc nnan ninf nsz arcp afn half %N, %I
-  %mul1.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %ETA
-  %mul2.i = fmul reassoc nnan ninf nsz arcp afn half %mul.i, %mul.i
-  %sub.i = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, %mul2.i
-  %mul3.i = fmul reassoc nnan ninf nsz arcp afn half %mul1.i, %sub.i
-  %sub4.i = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, %mul3.i
-  %mul5.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %I
-  %mul6.i = fmul reassoc nnan ninf nsz arcp afn half %ETA, %mul.i
-  %0 = tail call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %sub4.i)
-  %add.i = fadd reassoc nnan ninf nsz arcp afn half %0, %mul6.i
-  %mul7.i = fmul reassoc nnan ninf nsz arcp afn half %add.i, %N
-  %sub8.i = fsub reassoc nnan ninf nsz arcp afn half %mul5.i, %mul7.i
-  %cmp.i = fcmp reassoc nnan ninf nsz arcp afn olt half %sub4.i, 0xH0000
-  %hlsl.select.i = select reassoc nnan ninf nsz arcp afn i1 %cmp.i, half 0xH0000, half %sub8.i
-  ret half %hlsl.select.i
-}
\ No newline at end of file

>From 322bb3cefe540c239d715e3006ab10dfabe765cb Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Tue, 8 Jul 2025 17:05:44 -0700
Subject: [PATCH 21/23] SPV attributes

---
 clang/test/CodeGenHLSL/builtins/refract.hlsl | 12 ++++++------
 clang/test/CodeGenSPIRV/Builtins/refract.c   |  6 +++---
 llvm/include/llvm/IR/IntrinsicsSPIRV.td      |  6 +++---
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index 5fab2f988e7c5..1e1184bf626b5 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -70,7 +70,7 @@ half test_refract_half(half I, half N, half ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
 // SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, half %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <2 x half> [[SPV_REFRACT_I]]
 //
 half2 test_refract_half2(half2 I, half2 N, half ETA) {
@@ -100,7 +100,7 @@ half2 test_refract_half2(half2 I, half2 N, half ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
 // SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, half %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <3 x half> [[SPV_REFRACT_I]]
 //
 half3 test_refract_half3(half3 I, half3 N, half ETA) {
@@ -130,7 +130,7 @@ half3 test_refract_half3(half3 I, half3 N, half ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
 // SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, half %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <4 x half> [[SPV_REFRACT_I]]
 //
 half4 test_refract_half4(half4 I, half4 N, half ETA) {
@@ -203,7 +203,7 @@ float test_refract_float(float I, float N, float ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
 // SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, float %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <2 x float> [[SPV_REFRACT_I]]
 //
 float2 test_refract_float2(float2 I, float2 N, float ETA) {
@@ -233,7 +233,7 @@ float2 test_refract_float2(float2 I, float2 N, float ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
 // SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, float %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <3 x float> [[SPV_REFRACT_I]]
 //
 float3 test_refract_float3(float3 I, float3 N, float ETA) {
@@ -263,7 +263,7 @@ float3 test_refract_float3(float3 I, float3 N, float ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f(
 // SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) %{{.*}}, <4 x float> noundef nofpclass(nan inf) %{{.*}}, float noundef nofpclass(nan inf) %{{.*}}) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, float %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <4 x float> [[SPV_REFRACT_I]]
 //
 float4 test_refract_float4(float4 I, float4 N, float ETA) {
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index f477f532ffb6f..08256006edec4 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -7,7 +7,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
 // CHECK-LABEL: define spir_func <2 x float> @test_refract_float2(
 // CHECK-SAME: <2 x float> noundef [[I:%.*]], <2 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> [[I]], <2 x float> [[N]], float [[ETA]])
+// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32(<2 x float> [[I]], <2 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <2 x float> [[SPV_REFRACT]]
 //
 float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
@@ -15,7 +15,7 @@ float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spi
 // CHECK-LABEL: define spir_func <3 x float> @test_refract_float3(
 // CHECK-SAME: <3 x float> noundef [[I:%.*]], <3 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> [[I]], <3 x float> [[N]], float [[ETA]])
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32(<3 x float> [[I]], <3 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <3 x float> [[SPV_REFRACT]]
 //
 float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
@@ -23,7 +23,7 @@ float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spi
 // CHECK-LABEL: define spir_func <4 x float> @test_refract_float4(
 // CHECK-SAME: <4 x float> noundef [[I:%.*]], <4 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> [[I]], <4 x float> [[N]], float [[ETA]])
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32(<4 x float> [[I]], <4 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <4 x float> [[SPV_REFRACT]]
 //
 float4 test_refract_float4(float4 I, float4 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 241fe59ee8620..d24e68959df60 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -75,13 +75,13 @@ let TargetPrefix = "spv" in {
     [IntrNoMem] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
+  def int_spv_reflect : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_refract
       : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                               [llvm_anyfloat_ty, LLVMMatchType<0>,
-                               llvm_anyfloat_ty],
+                              LLVMVectorElementType<0>],
                               [IntrNoMem]>;
-  def int_spv_reflect : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
-  def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
+def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_smoothstep : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
   def int_spv_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [LLVMMatchType<0>, llvm_anyfloat_ty], [IntrNoMem]>;

>From bfdbc692b70d6b2f2583802b3860ffa79bb1fe23 Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Wed, 9 Jul 2025 13:50:15 -0700
Subject: [PATCH 22/23] Fix lit test failures in the Ubuntu CI pipeline

---
 clang/lib/Sema/SemaHLSL.cpp                   |  2 +-
 clang/lib/Sema/SemaSPIRV.cpp                  | 22 +------------------
 .../CodeGen/SPIRV/hlsl-intrinsics/refract.ll  |  2 +-
 3 files changed, 3 insertions(+), 23 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 505b11f722d7b..bad357b50929b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3995,4 +3995,4 @@ bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
   }
   Init = C;
   return true;
-}
\ No newline at end of file
+}
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 4b99983e823a2..8790e1abdb5bd 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -63,16 +63,6 @@ static bool CheckAllArgTypesAreCorrect(
   return false;
 }
 
-static bool CheckAllArgTypesAreCorrect(
-    Sema *S, CallExpr *TheCall,
-    llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
-  return CheckAllArgTypesAreCorrect(
-      S, TheCall,
-      SmallVector<
-          llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>, 4>(
-          TheCall->getNumArgs(), Check));
-}
-
 static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
                                            int ArgOrdinal,
                                            clang::QualType PassedType) {
@@ -90,9 +80,7 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
 static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
                                                  int ArgOrdinal,
                                                  clang::QualType PassedType) {
-  const auto *VecTy = PassedType->getAs<VectorType>();
-
-  if (VecTy || (!PassedType->isHalfType() && !PassedType->isFloat32Type()))
+  if (!PassedType->isHalfType() && !PassedType->isFloat32Type())
     return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
            << ArgOrdinal << /* scalar */ 1 << /* no int */ 0
            << /* half or float */ 2 << PassedType;
@@ -300,14 +288,6 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
                                    llvm::ArrayRef(ChecksArr)))
       return true;
 
-    ExprResult C = TheCall->getArg(2);
-    QualType ArgTyC = C.get()->getType();
-    if (!ArgTyC->isFloatingType()) {
-      SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC;
-      return true;
-    }
-
     QualType RetTy = TheCall->getArg(0)->getType();
     TheCall->setType(RetTy);
     break;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
index 58c86d9bc6aef..8642410b8493a 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -1,5 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; Make sure SPIRV operation function calls for refract are lowered correctly.
 

>From 2f9769ba23507010c7ebcfd0d4ca6518a5f2cc5f Mon Sep 17 00:00:00 2001
From: Anagha Rajendra Rao <anagrao at microsoft.com>
Date: Mon, 14 Jul 2025 10:30:10 -0700
Subject: [PATCH 23/23] Enable the SPIR-V scalar implementation of refract

---
 clang/lib/CodeGen/TargetBuiltins/SPIR.cpp     |  3 --
 clang/lib/Headers/hlsl/hlsl_detail.h          |  8 ----
 .../lib/Headers/hlsl/hlsl_intrinsic_helpers.h |  3 +-
 clang/lib/Sema/SemaSPIRV.cpp                  | 21 +++++++++
 clang/test/CodeGenHLSL/builtins/refract.hlsl  | 47 ++++---------------
 clang/test/CodeGenSPIRV/Builtins/refract.c    |  6 +--
 .../SemaHLSL/BuiltIns/refract-errors.hlsl     |  4 --
 .../test/SemaSPIRV/BuiltIns/refract-errors.c  | 18 +++++++
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +-
 .../CodeGen/SPIRV/hlsl-intrinsics/refract.ll  |  2 +-
 10 files changed, 55 insertions(+), 59 deletions(-)

diff --git a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
index 1c63e04f757c7..10abb682679fe 100644
--- a/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
@@ -66,9 +66,6 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
            E->getArg(1)->getType()->hasFloatingRepresentation() &&
            E->getArg(2)->getType()->isFloatingType() &&
            "refract operands must have a float representation");
-    assert(E->getArg(0)->getType()->isVectorType() &&
-           E->getArg(1)->getType()->isVectorType() &&
-           "refract I and N operands must be a vector");
     return Builder.CreateIntrinsic(
         /*ReturnType=*/I->getType(), Intrinsic::spv_refract,
         ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
diff --git a/clang/lib/Headers/hlsl/hlsl_detail.h b/clang/lib/Headers/hlsl/hlsl_detail.h
index 96e101a1e3aa8..80c4900121dfb 100644
--- a/clang/lib/Headers/hlsl/hlsl_detail.h
+++ b/clang/lib/Headers/hlsl/hlsl_detail.h
@@ -45,14 +45,6 @@ template <typename T> struct is_arithmetic {
   static const bool Value = __is_arithmetic(T);
 };
 
-template <typename T> struct is_vector {
-  static const bool value = false;
-};
-
-template <typename T, int N> struct is_vector<vector<T, N>> {
-  static const bool value = true;
-};
-
 template <typename T, int N>
 using HLSL_FIXED_VECTOR =
     vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
index 2ded062ea0d27..e8ccccb489815 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
@@ -73,8 +73,7 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
 
 template <typename T, typename U> constexpr T refract_impl(T I, T N, U Eta) {
 #if (__has_builtin(__builtin_spirv_refract))
-  if (is_vector<T>::value)
-    return __builtin_spirv_refract(I, N, Eta);
+  return __builtin_spirv_refract(I, N, Eta);
 #endif
   T Mul = dot(N, I);
   T K = 1 - Eta * Eta * (1 - Mul * Mul);
diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp
index 8790e1abdb5bd..87f3b0d8cbed6 100644
--- a/clang/lib/Sema/SemaSPIRV.cpp
+++ b/clang/lib/Sema/SemaSPIRV.cpp
@@ -287,6 +287,27 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
     if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
                                    llvm::ArrayRef(ChecksArr)))
       return true;
+    // Check that first two arguments are vectors of the same type
+    QualType Arg0Type = TheCall->getArg(0)->getType();
+    if (!SemaRef.getASTContext().hasSameUnqualifiedType(
+            Arg0Type, TheCall->getArg(1)->getType()))
+      return SemaRef.Diag(TheCall->getBeginLoc(),
+                          diag::err_vec_builtin_incompatible_vector)
+             << TheCall->getDirectCallee() << /* first two */ 0
+             << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                            TheCall->getArg(1)->getEndLoc());
+
+    // Check that scalar type of 3rd arg is same as base type of first two args
+    clang::QualType BaseType =
+        Arg0Type->isVectorType()
+            ? Arg0Type->castAs<clang::VectorType>()->getElementType()
+            : Arg0Type;
+    if (!SemaRef.getASTContext().hasSameUnqualifiedType(
+            BaseType, TheCall->getArg(2)->getType()))
+      return SemaRef.Diag(TheCall->getBeginLoc(),
+                          diag::err_hlsl_builtin_scalar_vector_mismatch)
+             << /* all */ 0 << TheCall->getDirectCallee() << Arg0Type
+             << TheCall->getArg(2)->getType();
 
     QualType RetTy = TheCall->getArg(0)->getType();
     TheCall->setType(RetTy);
diff --git a/clang/test/CodeGenHLSL/builtins/refract.hlsl b/clang/test/CodeGenHLSL/builtins/refract.hlsl
index 1e1184bf626b5..eda256451ee2b 100644
--- a/clang/test/CodeGenHLSL/builtins/refract.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/refract.hlsl
@@ -27,21 +27,8 @@
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) half @_Z17test_refract_halfDhDhDh(
 // SPVCHECK-SAME: half noundef nofpclass(nan inf) [[I:%.*]], half noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] 
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL2_I]]
-// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[MUL1_I]], [[SUB_I]]
-// SPVCHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL3_I]]
-// SPVCHECK:    [[MUL5_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half %{{.*}}, %{{.*}}
-// SPVCHECK:    [[TMP18:%.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.sqrt.f16(half %{{.*}})
-// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn half [[MUL6_I]], [[TMP18]]
-// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[ADD_I]], %{{.*}}
-// SPVCHECK:    [[SUB8_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half [[MUL5_I]], [[MUL7_I]]
-// SPVCHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt half %{{.*}}, 0xH0000
-// SPVCHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], half 0xH0000, half %{{.*}}
-// SPVCHECK:    ret half [[HLSL_SELECT_I]]
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef half @llvm.spv.refract.f16.f16(half %{{.*}}, half %{{.*}}, half %{{.*}})
+// SPVCHECK:    ret half [[SPV_REFRACT_I]]
 //
 half test_refract_half(half I, half N, half ETA) {
     return refract(I, N, ETA);
@@ -70,7 +57,7 @@ half test_refract_half(half I, half N, half ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
 // SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, half %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <2 x half> [[SPV_REFRACT_I]]
 //
 half2 test_refract_half2(half2 I, half2 N, half ETA) {
@@ -100,7 +87,7 @@ half2 test_refract_half2(half2 I, half2 N, half ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <3 x half> @_Z18test_refract_half3Dv3_DhS_Dh(
 // SPVCHECK-SAME: <3 x half> noundef nofpclass(nan inf) [[I:%.*]], <3 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, half %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x half> @llvm.spv.refract.v3f16.f16(<3 x half> %{{.*}}, <3 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <3 x half> [[SPV_REFRACT_I]]
 //
 half3 test_refract_half3(half3 I, half3 N, half ETA) {
@@ -130,7 +117,7 @@ half3 test_refract_half3(half3 I, half3 N, half ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <4 x half> @_Z18test_refract_half4Dv4_DhS_Dh(
 // SPVCHECK-SAME: <4 x half> noundef nofpclass(nan inf) [[I:%.*]], <4 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, half %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x half> @llvm.spv.refract.v4f16.f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, half %{{.*}})
 // SPVCHECK:    ret <4 x half> [[SPV_REFRACT_I]]
 //
 half4 test_refract_half4(half4 I, half4 N, half ETA) {
@@ -156,25 +143,11 @@ half4 test_refract_half4(half4 I, half4 N, half ETA) {
 // CHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], float 0.000000e+00, float %{{.*}}
 // CHECK:    ret float [[HLSL_SELECT_I]]
 //
-//
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) float @_Z18test_refract_floatfff(
 // SPVCHECK-SAME: float noundef nofpclass(nan inf) [[I:%.*]], float noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL1_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL2_I]]
-// SPVCHECK:    [[MUL3_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[MUL1_I]], [[SUB_I]]
-// SPVCHECK:    [[SUB4_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float 1.000000e+00, [[MUL3_I]]
-// SPVCHECK:    [[MUL5_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[MUL6_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float %{{.*}}, %{{.*}}
-// SPVCHECK:    [[TMP18:%.*]] = call reassoc nnan ninf nsz arcp afn float @llvm.sqrt.f32(float %{{.*}})
-// SPVCHECK:    [[ADD_I:%.*]] = fadd reassoc nnan ninf nsz arcp afn float [[MUL6_I]], [[TMP18]]
-// SPVCHECK:    [[MUL7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn float [[ADD_I]], %{{.*}}
-// SPVCHECK:    [[SUB8_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn float [[MUL5_I]], [[MUL7_I]]
-// SPVCHECK:    [[CMP_I:%.*]] = fcmp reassoc nnan ninf nsz arcp afn olt float %{{.*}}, 0.000000e+00
-// SPVCHECK:    [[HLSL_SELECT_I:%.*]] = select reassoc nnan ninf nsz arcp afn i1 [[CMP_I]], float 0.000000e+00, float %{{.*}}
-// SPVCHECK:    ret float [[HLSL_SELECT_I]]
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef float @llvm.spv.refract.f32.f32(float %{{.*}}, float  %{{.*}}, float %{{.*}})
+// SPVCHECK:    ret float [[SPV_REFRACT_I]]
 //
 float test_refract_float(float I, float N, float ETA) {
     return refract(I, N, ETA);
@@ -203,7 +176,7 @@ float test_refract_float(float I, float N, float ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <2 x float> @_Z19test_refract_float2Dv2_fS_f(
 // SPVCHECK-SAME: <2 x float> noundef nofpclass(nan inf) [[I:%.*]], <2 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, float %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> %{{.*}}, <2 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <2 x float> [[SPV_REFRACT_I]]
 //
 float2 test_refract_float2(float2 I, float2 N, float ETA) {
@@ -233,7 +206,7 @@ float2 test_refract_float2(float2 I, float2 N, float ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <3 x float> @_Z19test_refract_float3Dv3_fS_f(
 // SPVCHECK-SAME: <3 x float> noundef nofpclass(nan inf) [[I:%.*]], <3 x float> noundef nofpclass(nan inf) [[N:%.*]], float noundef nofpclass(nan inf) [[ETA:%.*]]) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, float %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> %{{.*}}, <3 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <3 x float> [[SPV_REFRACT_I]]
 //
 float3 test_refract_float3(float3 I, float3 N, float ETA) {
@@ -263,7 +236,7 @@ float3 test_refract_float3(float3 I, float3 N, float ETA) {
 // SPVCHECK-LABEL: define hidden spir_func noundef nofpclass(nan inf) <4 x float> @_Z19test_refract_float4Dv4_fS_f(
 // SPVCHECK-SAME: <4 x float> noundef nofpclass(nan inf) %{{.*}}, <4 x float> noundef nofpclass(nan inf) %{{.*}}, float noundef nofpclass(nan inf) %{{.*}}) #[[ATTR0:[0-9]+]] {
 // SPVCHECK:  [[ENTRY:.*:]]
-// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, float %{{.*}})
+// SPVCHECK:    [[SPV_REFRACT_I:%.*]] = call reassoc nnan ninf nsz arcp afn noundef <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, float %{{.*}})
 // SPVCHECK:    ret <4 x float> [[SPV_REFRACT_I]]
 //
 float4 test_refract_float4(float4 I, float4 N, float ETA) {
diff --git a/clang/test/CodeGenSPIRV/Builtins/refract.c b/clang/test/CodeGenSPIRV/Builtins/refract.c
index 08256006edec4..f477f532ffb6f 100644
--- a/clang/test/CodeGenSPIRV/Builtins/refract.c
+++ b/clang/test/CodeGenSPIRV/Builtins/refract.c
@@ -7,7 +7,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
 // CHECK-LABEL: define spir_func <2 x float> @test_refract_float2(
 // CHECK-SAME: <2 x float> noundef [[I:%.*]], <2 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32(<2 x float> [[I]], <2 x float> [[N]], float [[ETA]])
+// CHECK:    [[SPV_REFRACT:%.*]] = tail call <2 x float> @llvm.spv.refract.v2f32.f32(<2 x float> [[I]], <2 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <2 x float> [[SPV_REFRACT]]
 //
 float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
@@ -15,7 +15,7 @@ float2 test_refract_float2(float2 I, float2 N, float eta) { return __builtin_spi
 // CHECK-LABEL: define spir_func <3 x float> @test_refract_float3(
 // CHECK-SAME: <3 x float> noundef [[I:%.*]], <3 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32(<3 x float> [[I]], <3 x float> [[N]], float [[ETA]])
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <3 x float> @llvm.spv.refract.v3f32.f32(<3 x float> [[I]], <3 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <3 x float> [[SPV_REFRACT]]
 //
 float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
@@ -23,7 +23,7 @@ float3 test_refract_float3(float3 I, float3 N, float eta) { return __builtin_spi
 // CHECK-LABEL: define spir_func <4 x float> @test_refract_float4(
 // CHECK-SAME: <4 x float> noundef [[I:%.*]], <4 x float> noundef [[N:%.*]], float noundef [[ETA:%.*]]) local_unnamed_addr #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]
-// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32(<4 x float> [[I]], <4 x float> [[N]], float [[ETA]])
+// CHECK-NEXT:    [[SPV_REFRACT:%.*]] = tail call <4 x float> @llvm.spv.refract.v4f32.f32(<4 x float> [[I]], <4 x float> [[N]], float [[ETA]])
 // CHECK-NEXT:    ret <4 x float> [[SPV_REFRACT]]
 //
 float4 test_refract_float4(float4 I, float4 N, float eta) { return __builtin_spirv_refract(I, N, eta); }
diff --git a/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
index fe094591bea6a..791cb8a60b44c 100644
--- a/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl
@@ -58,10 +58,6 @@ float3 test_mixed_datatype_inputs(float3 p0, float3 p1, half p2) {
   return refract(p0, p1, p2);
 }
 
-half3 test_mixed_datatype_inputs(half3 p0, half3 p1, float p2) {
-  return refract(p0, p1, p2);
-}
-
 typedef float float5 __attribute__((ext_vector_type(5)));
 
 float5 test_vec5_inputs(float5 p0, float5 p1,  float p2) {
diff --git a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
index 775f82a858dc2..72b2b685ef15c 100644
--- a/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
+++ b/clang/test/SemaSPIRV/BuiltIns/refract-errors.c
@@ -1,6 +1,9 @@
 // RUN: %clang_cc1 %s -triple spirv-pc-vulkan-compute -verify
 
 typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float3 __attribute__((ext_vector_type(3)));
+typedef _Float16 half;
+typedef half half2 __attribute__((ext_vector_type(2)));
 
 float2 test_no_third_arg(float2 p0) {
   return __builtin_spirv_refract(p0, p0);
@@ -21,3 +24,18 @@ float test_int_scalar_inputs(int p0, int p1, int p2) {
   return __builtin_spirv_refract(p0, p1, p2);
   //  expected-error at -1 {{1st argument must be a scalar or vector of 16 or 32 bit floating-point types (was 'int')}}
 }
+
+float test_float_and_half_inputs(float2 p0, half2 p1, float p2) {
+  return __builtin_spirv_refract(p0, p1, p2);
+  //  expected-error at -1 {{2nd argument must be a scalar or vector of 16 or 32 bit floating-point types (was 'half2' (vector of 2 'half' values))}}
+}
+
+float test_float_and_half_2_inputs(float2 p0, float2 p1, half p2) {
+  return __builtin_spirv_refract(p0, p1, p2);
+  //  expected-error at -1 {{3rd argument must be a scalar 16 or 32 bit floating-point type (was 'half' (aka '_Float16'))}}
+}
+
+float2 test_mismatch_vector_size_inputs(float2 p0, float3 p1, float p2) {
+  return __builtin_spirv_refract(p0, p1, p2);
+  //  expected-error at -1 {{first two arguments to '__builtin_spirv_refract' must have the same type}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index d24e68959df60..d9577e7af4bac 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -79,7 +79,7 @@ let TargetPrefix = "spv" in {
   def int_spv_refract
       : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                               [llvm_anyfloat_ty, LLVMMatchType<0>,
-                              LLVMVectorElementType<0>],
+                              llvm_anyfloat_ty],
                               [IntrNoMem]>;
 def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
   def int_spv_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
index 8642410b8493a..b18e929568534 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll
@@ -1,5 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
 
 ; Make sure SPIRV operation function calls for refract are lowered correctly.
 



More information about the llvm-commits mailing list