[clang] [HLSL] Allow truncation to scalar (PR #104844)

Chris B via cfe-commits cfe-commits at lists.llvm.org
Tue Sep 10 06:48:35 PDT 2024


https://github.com/llvm-beanz updated https://github.com/llvm/llvm-project/pull/104844

>From 1a1a92aff834aa2f6f12d3de001714d8338dd274 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Tue, 13 Aug 2024 15:51:34 -0500
Subject: [PATCH 1/4] [HLSL] Allow truncation to scalar

HLSL allows implicit conversions that truncate vectors to scalar
prvalues. These conversions are ranked as vector truncations during
overload resolution and emit the corresponding implicit-conversion
warnings.

This change allows forming a truncation cast that yields a prvalue,
but not an lvalue. A vector is truncated to a scalar by extracting its
first element and discarding the remaining elements.

Fixes #102964
---
 clang/lib/CodeGen/CGExprScalar.cpp            | 17 ++++---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  5 +-
 clang/lib/Sema/SemaExprCXX.cpp                | 49 ++++++++++--------
 clang/lib/Sema/SemaOverload.cpp               | 50 ++++++++++++-------
 .../standard_conversion_sequences.hlsl        | 24 +++++++++
 clang/test/CodeGenHLSL/builtins/dot.hlsl      | 16 ------
 clang/test/CodeGenHLSL/builtins/lerp.hlsl     | 21 --------
 clang/test/CodeGenHLSL/builtins/mad.hlsl      | 18 -------
 .../TruncationOverloadResolution.hlsl         | 42 ++++++++++++++--
 .../BuiltinVector/ScalarSwizzleErrors.hlsl    |  6 ++-
 10 files changed, 141 insertions(+), 107 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 84392745ea6144..2ccdb840579a98 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2692,14 +2692,19 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return CGF.CGM.createOpenCLIntToSamplerConversion(E, CGF);
 
   case CK_HLSLVectorTruncation: {
-    assert(DestTy->isVectorType() && "Expected dest type to be vector type");
+    assert((DestTy->isVectorType() || DestTy->isBuiltinType()) &&
+           "Destination type must be a vector or builtin type.");
     Value *Vec = Visit(const_cast<Expr *>(E));
-    SmallVector<int, 16> Mask;
-    unsigned NumElts = DestTy->castAs<VectorType>()->getNumElements();
-    for (unsigned I = 0; I != NumElts; ++I)
-      Mask.push_back(I);
+    if (auto *VecTy = DestTy->getAs<VectorType>()) {
+      SmallVector<int, 16> Mask;
+      unsigned NumElts = VecTy->getNumElements();
+      for (unsigned I = 0; I != NumElts; ++I)
+        Mask.push_back(I);
 
-    return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+      return Builder.CreateShuffleVector(Vec, Mask, "trunc");
+    }
+    llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
+    return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
 
   } // end of switch
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 678cdc77f8a71b..be65e149515b9f 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -673,9 +673,6 @@ float dot(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 float dot(float4, float4);
 
-_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
-double dot(double, double);
-
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 int dot(int, int);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
@@ -916,7 +913,7 @@ float4 lerp(float4, float4, float4);
 /// \brief Returns the length of the specified floating-point vector.
 /// \param x [in] The vector of floats, or a scalar float.
 ///
-/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + �).
+/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
 
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 5356bcf172f752..e62e3bcbead3be 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4301,8 +4301,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
 // from type to the elements of the to type without resizing the vector.
 static QualType adjustVectorType(ASTContext &Context, QualType FromTy,
                                  QualType ToType, QualType *ElTy = nullptr) {
-  auto *ToVec = ToType->castAs<VectorType>();
-  QualType ElType = ToVec->getElementType();
+  QualType ElType = ToType;
+  if (auto *ToVec = ToType->getAs<VectorType>())
+    ElType = ToVec->getElementType();
+
   if (ElTy)
     *ElTy = ElType;
   if (!FromTy->isVectorType())
@@ -4463,7 +4465,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Integral_Conversion: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isBooleanType()) {
       assert(FromType->castAs<EnumType>()->getDecl()->isFixed() &&
@@ -4483,7 +4485,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Promotion:
   case ICK_Floating_Conversion: {
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType);
     From = ImpCastExprToType(From, StepTy, CK_FloatingCast, VK_PRValue,
                              /*BasePath=*/nullptr, CCK)
@@ -4515,7 +4517,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
   case ICK_Floating_Integral: {
     QualType ElTy = ToType;
     QualType StepTy = ToType;
-    if (ToType->isVectorType())
+    if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isRealFloatingType())
       From = ImpCastExprToType(From, StepTy, CK_IntegralToFloating, VK_PRValue,
@@ -4656,11 +4658,11 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     QualType ElTy = FromType;
     QualType StepTy = ToType;
-    if (FromType->isVectorType()) {
-      if (getLangOpts().HLSL)
-        StepTy = adjustVectorType(Context, FromType, ToType);
+    if (FromType->isVectorType())
       ElTy = FromType->castAs<VectorType>()->getElementType();
-    }
+    if (getLangOpts().HLSL &&
+        (FromType->isVectorType() || ToType->isVectorType()))
+      StepTy = adjustVectorType(Context, FromType, ToType);
 
     From = ImpCastExprToType(From, StepTy, ScalarTypeToBooleanCastKind(ElTy),
                              VK_PRValue,
@@ -4815,8 +4817,8 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     // TODO: Support HLSL matrices.
     assert((!From->getType()->isMatrixType() && !ToType->isMatrixType()) &&
            "Dimension conversion for matrix types is not implemented yet.");
-    assert(ToType->isVectorType() &&
-           "Dimension conversion is only supported for vector types.");
+    assert((ToType->isVectorType() || ToType->isBuiltinType()) &&
+           "Dimension conversion output must be vector or scalar type.");
     switch (SCS.Dimension) {
     case ICK_HLSL_Vector_Splat: {
       // Vector splat from any arithmetic type to a vector.
@@ -4828,18 +4830,23 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     }
     case ICK_HLSL_Vector_Truncation: {
       // Note: HLSL built-in vectors are ExtVectors. Since this truncates a
-      // vector to a smaller vector, this can only operate on arguments where
-      // the source and destination types are ExtVectors.
-      assert(From->getType()->isExtVectorType() && ToType->isExtVectorType() &&
-             "HLSL vector truncation should only apply to ExtVectors");
+      // vector to a smaller vector or to a scalar, this can only operate on
+      // arguments where the source type is an ExtVector and the destination
+      // type is destination type is either an ExtVectorType or a builtin scalar
+      // type.
       auto *FromVec = From->getType()->castAs<VectorType>();
-      auto *ToVec = ToType->castAs<VectorType>();
       QualType ElType = FromVec->getElementType();
-      QualType TruncTy =
-          Context.getExtVectorType(ElType, ToVec->getNumElements());
-      From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
-                               From->getValueKind())
-                 .get();
+      if (auto *ToVec = ToType->getAs<VectorType>()) {
+        QualType TruncTy =
+            Context.getExtVectorType(ElType, ToVec->getNumElements());
+        From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      } else {
+        From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
+                                 From->getValueKind())
+                   .get();
+      }
       break;
     }
     case ICK_Identity:
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 52f640eb96b73b..f6a097305564a0 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -2032,26 +2032,42 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
   if (S.Context.hasSameUnqualifiedType(FromType, ToType))
     return false;
 
+  // HLSL allows implicit truncation of vector types.
+  if (S.getLangOpts().HLSL) {
+    auto *ToExtType = ToType->getAs<ExtVectorType>();
+    auto *FromExtType = FromType->getAs<ExtVectorType>();
+
+    // If both arguments are vectors, handle possible vector truncation and
+    // element conversion.
+    if (ToExtType && FromExtType) {
+      unsigned FromElts = FromExtType->getNumElements();
+      unsigned ToElts = ToExtType->getNumElements();
+      if (FromElts < ToElts)
+        return false;
+      if (FromElts == ToElts)
+        ElConv = ICK_Identity;
+      else
+        ElConv = ICK_HLSL_Vector_Truncation;
+
+      QualType FromElTy = FromExtType->getElementType();
+      QualType ToElTy = ToExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
+    }
+    if (FromExtType && nullptr == ToExtType) {
+      ElConv = ICK_HLSL_Vector_Truncation;
+      QualType FromElTy = FromExtType->getElementType();
+      if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))
+        return true;
+      return IsVectorElementConversion(S, FromElTy, ToType, ICK, From);
+    }
+    // Fallthrough for the case where ToType is a vector and FromType is not.
+  }
+
   // There are no conversions between extended vector types, only identity.
   if (auto *ToExtType = ToType->getAs<ExtVectorType>()) {
     if (auto *FromExtType = FromType->getAs<ExtVectorType>()) {
-      // HLSL allows implicit truncation of vector types.
-      if (S.getLangOpts().HLSL) {
-        unsigned FromElts = FromExtType->getNumElements();
-        unsigned ToElts = ToExtType->getNumElements();
-        if (FromElts < ToElts)
-          return false;
-        if (FromElts == ToElts)
-          ElConv = ICK_Identity;
-        else
-          ElConv = ICK_HLSL_Vector_Truncation;
-
-        QualType FromElTy = FromExtType->getElementType();
-        QualType ToElTy = ToExtType->getElementType();
-        if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
-          return true;
-        return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
-      }
       // There are no conversions between extended vector types other than the
       // identity conversion.
       return false;
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
index 5d751be6dae066..6478ea67e32a0d 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/standard_conversion_sequences.hlsl
@@ -117,3 +117,27 @@ void d4_to_b2() {
   vector<double,4> d4 = 9.0;
   vector<bool, 2> b2 = d4;
 }
+
+// CHECK-LABEL: d4_to_d1
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d1:%.*]] = alloca <1 x double>
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[vecd1:%.*]] = shufflevector <4 x double> [[vecd4]], <4 x double> poison, <1 x i32> zeroinitializer
+// CHECK: store <1 x double> [[vecd1]], ptr [[d1:%.*]], align 8
+void d4_to_d1() {
+  vector<double,4> d4 = 9.0;
+  vector<double,1> d1 = d4;
+}
+
+// CHECK-LABEL: d4_to_dScalar
+// CHECK: [[d4:%.*]] = alloca <4 x double>
+// CHECK: [[d:%.*]] = alloca double
+// CHECK: store <4 x double> <double 9.000000e+00, double 9.000000e+00, double 9.000000e+00, double 9.000000e+00>, ptr [[d4]]
+// CHECK: [[vecd4:%.*]] = load <4 x double>, ptr [[d4]]
+// CHECK: [[d4x:%.*]] = extractelement <4 x double> [[vecd4]], i32 0
+// CHECK: store double [[d4x]], ptr [[d]]
+void d4_to_dScalar() {
+  vector<double,4> d4 = 9.0;
+  double d = d4;
+}
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index ae6e45c3f9482a..51ce5a88b302b9 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -143,19 +143,3 @@ float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 // CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1)
 // CHECK: ret float %dx.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); }
-
-// CHECK:  %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1)
-// CHECK: ret float %dx.dot
-float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); }
-
-// CHECK: %dx.dot = fmul double %0, %1
-// CHECK: ret double %dx.dot
-double test_dot_double(double p0, double p1) { return dot(p0, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/lerp.hlsl b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
index 53ac24dd456930..4c0cbe8a671a9b 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp.hlsl
@@ -65,24 +65,3 @@ float3 test_lerp_float3(float3 p0) { return lerp(p0, p0, p0); }
 // SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}})
 // CHECK: ret <4 x float> %hlsl.lerp
 float4 test_lerp_float4(float4 p0) { return lerp(p0, p0, p0); }
-
-// CHECK: %[[b:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[c:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// DXIL_CHECK: %hlsl.lerp = call <2 x float> @llvm.dx.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <2 x float> @llvm.spv.lerp.v2f32(<2 x float> %splat.splat, <2 x float> %[[b]], <2 x float> %[[c]])
-// CHECK: ret <2 x float> %hlsl.lerp
-float2 test_lerp_float2_splat(float p0, float2 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <3 x float> @llvm.dx.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <3 x float> @llvm.spv.lerp.v3f32(<3 x float> %splat.splat, <3 x float> %[[b]], <3 x float> %[[c]])
-// CHECK: ret <3 x float> %hlsl.lerp
-float3 test_lerp_float3_splat(float p0, float3 p1) { return lerp(p0, p1, p1); }
-
-// CHECK: %[[b:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[c:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// DXIL_CHECK: %hlsl.lerp = call <4 x float> @llvm.dx.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// SPIR_CHECK: %hlsl.lerp = call <4 x float> @llvm.spv.lerp.v4f32(<4 x float> %splat.splat, <4 x float> %[[b]], <4 x float> %[[c]])
-// CHECK:  ret <4 x float> %hlsl.lerp
-float4 test_lerp_float4_splat(float p0, float4 p1) { return lerp(p0, p1, p1); }
diff --git a/clang/test/CodeGenHLSL/builtins/mad.hlsl b/clang/test/CodeGenHLSL/builtins/mad.hlsl
index 449a793caf93b7..265a2552c80fb4 100644
--- a/clang/test/CodeGenHLSL/builtins/mad.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/mad.hlsl
@@ -263,21 +263,3 @@ uint64_t3 test_mad_uint64_t3(uint64_t3 p0, uint64_t3 p1, uint64_t3 p2) { return
 // SPIR_CHECK: mul nuw <4 x i64>  %{{.*}}, %{{.*}}
 // SPIR_CHECK: add nuw <4 x i64>  %{{.*}}, %{{.*}}
 uint64_t4 test_mad_uint64_t4(uint64_t4 p0, uint64_t4 p1, uint64_t4 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <2 x float>, ptr %p1.addr, align 8
-// CHECK: %[[p2:.*]] = load <2 x float>, ptr %p2.addr, align 8
-// CHECK: %hlsl.fmad = call <2 x float>  @llvm.fmuladd.v2f32(<2 x float> %splat.splat, <2 x float> %[[p1]], <2 x float> %[[p2]])
-// CHECK: ret <2 x float> %hlsl.fmad
-float2 test_mad_float2_splat(float p0, float2 p1, float2 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <3 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <3 x float>, ptr %p2.addr, align 16
-// CHECK: %hlsl.fmad = call <3 x float>  @llvm.fmuladd.v3f32(<3 x float> %splat.splat, <3 x float> %[[p1]], <3 x float> %[[p2]])
-// CHECK: ret <3 x float> %hlsl.fmad
-float3 test_mad_float3_splat(float p0, float3 p1, float3 p2) { return mad(p0, p1, p2); }
-
-// CHECK: %[[p1:.*]] = load <4 x float>, ptr %p1.addr, align 16
-// CHECK: %[[p2:.*]] = load <4 x float>, ptr %p2.addr, align 16
-// CHECK:  %hlsl.fmad = call <4 x float>  @llvm.fmuladd.v4f32(<4 x float> %splat.splat, <4 x float> %[[p1]], <4 x float> %[[p2]])
-// CHECK:  ret <4 x float> %hlsl.fmad
-float4 test_mad_float4_splat(float p0, float4 p1, float4 p2) { return mad(p0, p1, p2); }
diff --git a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
index f8cfe22372e885..cb00d2e226b7cf 100644
--- a/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/TruncationOverloadResolution.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -fsyntax-only %s -DERROR=1 -verify
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -Wconversion -fsyntax-only %s -DERROR=1 -verify
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -fnative-half-type -finclude-default-header -ast-dump %s | FileCheck %s
 
 // Case 1: Prefer exact-match truncation over conversion.
@@ -32,6 +32,42 @@ void Case2(float4 F) {
   Half2Double2(F); // expected-warning{{implicit conversion truncates vector: 'float4' (aka 'vector<float, 4>') to 'vector<double, 2>' (vector of 2 'double' values)}}
 }
 
+// Case 3: Allow truncation down to vector<T,1> or T.
+void Half(half H);
+void Float(float F);
+void Double(double D);
+
+void Half1(half1 H);
+void Float1(float1 F);
+void Double1(double1 D);
+
+void Case3(half3 H, float3 F, double3 D) {
+  Half(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'half'}}
+  Half(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'half'}}
+  Half(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'half'}}
+
+  Float(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'float'}}
+  Float(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'float'}}
+  Float(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'float'}}
+
+  Double(H); // expected-warning{{implicit conversion turns vector to scalar: 'half3' (aka 'vector<half, 3>') to 'double'}}
+  Double(F); // expected-warning{{implicit conversion turns vector to scalar: 'float3' (aka 'vector<float, 3>') to 'double'}}
+  Double(D); // expected-warning{{implicit conversion turns vector to scalar: 'double3' (aka 'vector<double, 3>') to 'double'}}
+
+  Half1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'float3' (aka 'vector<float, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+  Half1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to 'vector<half, 1>' (vector of 1 'half' value)}}
+
+  Float1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+  Float1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}} expected-warning{{implicit conversion loses floating-point precision: 'double3' (aka 'vector<double, 3>') to 'vector<float, 1>' (vector of 1 'float' value)}}
+
+  Double1(H); // expected-warning{{implicit conversion truncates vector: 'half3' (aka 'vector<half, 3>') to 'vector<double, 1>' (vector of 1 'double' value)}}
+  Double1(F); // expected-warning{{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<double, 1>' (vector of 1 'double' value)}}
+  Double1(D); // expected-warning{{implicit conversion truncates vector: 'double3' (aka 'vector<double, 3>') to 'vector<double, 1>' (vector of 1 'double' value)}}
+}
+
+
 #if ERROR
 // Case 3: Two promotions or two conversions are ambiguous.
 void Float2Double2(double2 D); // expected-note{{candidate function}}
@@ -55,8 +91,8 @@ void Case1(half4 H, float4 F, double4 D) {
   Half2Half3(F); // expected-error {{call to 'Half2Half3' is ambiguous}}
   Half2Half3(D); // expected-error {{call to 'Half2Half3' is ambiguous}}
   Half2Half3(H.xyz);
-  Half2Half3(F.xyz);
-  Half2Half3(D.xyz);
+  Half2Half3(F.xyz); // expected-warning{{implicit conversion loses floating-point precision: 'vector<float, 3>' (vector of 3 'float' values) to 'vector<half, 3>' (vector of 3 'half' values)}}
+  Half2Half3(D.xyz); // expected-warning{{implicit conversion loses floating-point precision: 'vector<double, 3>' (vector of 3 'double' values) to 'vector<half, 3>' (vector of 3 'half' values)}}
 
   Double2Double3(H); // expected-error {{call to 'Double2Double3' is ambiguous}}
   Double2Double3(F); // expected-error {{call to 'Double2Double3' is ambiguous}}
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
index 5088991f2e28ac..b1c75acbc16c6f 100644
--- a/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/ScalarSwizzleErrors.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library  -x hlsl -finclude-default-header -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -verify %s
 
 int2 ToTwoInts(int V) {
   return V.xy; // expected-error{{vector component access exceeds type 'vector<int, 1>' (vector of 1 'int' value)}}
@@ -16,6 +16,10 @@ float2 WhatIsHappening(float V) {
   return V.; // expected-error{{expected unqualified-id}}
 }
 
+float ScalarLValue(float2 V) {
+  (float)V = 4.0; // expected-error{{assignment to cast is illegal, lvalue casts are not supported}}
+}
+
 // These cases produce no error.
 
 float2 HowManyFloats(float V) {

>From 7a6e1ad4a3456d1cce71963d60df644f536722c6 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Mon, 9 Sep 2024 15:36:40 -0500
Subject: [PATCH 2/4] Re-add dot(double,double) overload

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h | 3 +++
 clang/test/CodeGenHLSL/builtins/dot.hlsl | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 711507458267a2..2ac18056b0fc3d 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -673,6 +673,9 @@ float dot(float3, float3);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 float dot(float4, float4);
 
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double, double);
+
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
 int dot(int, int);
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
index 16d1c4e52cfe83..3f6be04a595e23 100644
--- a/clang/test/CodeGenHLSL/builtins/dot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -154,3 +154,7 @@ float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); }
 // CHECK: %hlsl.dot = call float @llvm.[[ICF]].fdot.v4f32(<4 x float>
 // CHECK: ret float %hlsl.dot
 float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); }
+
+// CHECK: %hlsl.dot = fmul double
+// CHECK: ret double %hlsl.dot
+double test_dot_double(double p0, double p1) { return dot(p0, p1); }

>From 6521f2d7f9f937fedebd905c3a27fe5c49c79bd7 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Mon, 9 Sep 2024 16:41:08 -0500
Subject: [PATCH 3/4] Update based on PR feedback from @bogner

---
 clang/lib/CodeGen/CGExprScalar.cpp |  2 +-
 clang/lib/Sema/SemaExprCXX.cpp     | 19 +++++++------------
 clang/lib/Sema/SemaOverload.cpp    |  2 +-
 3 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 9a19bced519882..057e35a752dd65 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2709,7 +2709,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
            "Destination type must be a vector or builtin type.");
     Value *Vec = Visit(const_cast<Expr *>(E));
     if (auto *VecTy = DestTy->getAs<VectorType>()) {
-      SmallVector<int, 16> Mask;
+      SmallVector<int> Mask;
       unsigned NumElts = VecTy->getNumElements();
       for (unsigned I = 0; I != NumElts; ++I)
         Mask.push_back(I);
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index f380280f070114..e8df2105fd605a 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -4848,18 +4848,13 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
       // type is destination type is either an ExtVectorType or a builtin scalar
       // type.
       auto *FromVec = From->getType()->castAs<VectorType>();
-      QualType ElType = FromVec->getElementType();
-      if (auto *ToVec = ToType->getAs<VectorType>()) {
-        QualType TruncTy =
-            Context.getExtVectorType(ElType, ToVec->getNumElements());
-        From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
-                                 From->getValueKind())
-                   .get();
-      } else {
-        From = ImpCastExprToType(From, ElType, CK_HLSLVectorTruncation,
-                                 From->getValueKind())
-                   .get();
-      }
+      QualType TruncTy = FromVec->getElementType();
+      if (auto *ToVec = ToType->getAs<VectorType>())
+        TruncTy = Context.getExtVectorType(TruncTy, ToVec->getNumElements());
+      From = ImpCastExprToType(From, TruncTy, CK_HLSLVectorTruncation,
+                               From->getValueKind())
+                 .get();
+
       break;
     }
     case ICK_Identity:
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 035ead76d96450..ea72d3f003cbc4 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -2055,7 +2055,7 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
         return true;
       return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
     }
-    if (FromExtType && nullptr == ToExtType) {
+    if (FromExtType && !ToExtType) {
       ElConv = ICK_HLSL_Vector_Truncation;
       QualType FromElTy = FromExtType->getElementType();
       if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))

>From 1403c86eb24362350403e13dc23944a39ee11bb0 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Tue, 10 Sep 2024 08:43:12 -0500
Subject: [PATCH 4/4] Fix constant expression evaluation

This fixes constant evaluation of truncation casts whose result is a
vector, an integer, or a floating-point scalar, and allows the
truncation cast kind to produce a bool result.
---
 clang/lib/AST/Expr.cpp                        |  2 +-
 clang/lib/AST/ExprConstant.cpp                | 22 ++++++++++++++++++-
 .../BuiltinVector/TruncationConstantExpr.hlsl | 20 +++++++++++++++++
 3 files changed, 42 insertions(+), 2 deletions(-)
 create mode 100644 clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl

diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 6545912ed160d9..e10142eff8ec47 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1924,7 +1924,6 @@ bool CastExpr::CastConsistency() const {
   case CK_FixedPointToIntegral:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
-  case CK_HLSLVectorTruncation:
     assert(!getType()->isBooleanType() && "unheralded conversion to bool");
     goto CheckNoBasePath;
 
@@ -1945,6 +1944,7 @@ bool CastExpr::CastConsistency() const {
   case CK_BuiltinFnToFnPtr:
   case CK_FixedPointToBoolean:
   case CK_HLSLArrayRValue:
+  case CK_HLSLVectorTruncation:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 78d25006360042..6387e375dda79c 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -10935,6 +10935,15 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
 
     return true;
   }
+  case CK_HLSLVectorTruncation: {
+    APValue Val;
+    SmallVector<APValue, 4> Elements;
+    if (!EvaluateVector(SE, Val, Info))
+      return Error(E);
+    for (unsigned I = 0; I < NElts; I++)
+      Elements.push_back(Val.getVectorElt(I));
+    return Success(Elements, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -14478,7 +14487,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_FixedPointCast:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
-  case CK_HLSLVectorTruncation:
     llvm_unreachable("invalid cast kind for integral value");
 
   case CK_BitCast:
@@ -14651,6 +14659,12 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return false;
     return Success(Value, E);
   }
+  case CK_HLSLVectorTruncation: {
+    APValue Val;
+    if (!EvaluateVector(SubExpr, Val, Info))
+      return Error(E);
+    return Success(Val.getVectorElt(0), E);
+  }
   }
 
   llvm_unreachable("unknown cast resulting in integral value");
@@ -15177,6 +15191,12 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
     Result = V.getComplexFloatReal();
     return true;
   }
+  case CK_HLSLVectorTruncation: {
+    APValue Val;
+    if (!EvaluateVector(SubExpr, Val, Info))
+      return Error(E);
+    return Success(Val.getVectorElt(0), E);
+  }
   }
 }
 
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
new file mode 100644
index 00000000000000..918daa03d80322
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+// Note: these tests are a bit awkward because at time of writing we don't have a
+// good way to constexpr `any` for bool vector conditions, and the condition for
+// _Static_assert must be an integral constant.
+export void fn() {
+  // This compiling successfully verifies that the vector constant expression
+  // gets truncated to an integer at compile time for instantiation.
+  _Static_assert(((int)1.xxxx) + 0 == 1, "Woo!");
+
+  // This compiling successfully verifies that the vector constant expression
+  // gets truncated to a float at compile time for instantiation.
+  _Static_assert(((float)1.0.xxxx) + 0.0 == 1.0, "Woo!");
+
+  // This compiling successfully verifies that a vector can be truncated to a
+  // smaller vector, then truncated to a float as a constant expression.
+  _Static_assert(((float2)float4(6, 5, 4, 3)).x == 6, "Woo!");
+}



More information about the cfe-commits mailing list