[clang] 8709bca - clang: Add __builtin_elementwise_fma

Matt Arsenault via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 24 17:55:17 PST 2023


Author: Matt Arsenault
Date: 2023-02-24T21:55:08-04:00
New Revision: 8709bcacfb3a06847b47bb6b47e8556db43f3a43

URL: https://github.com/llvm/llvm-project/commit/8709bcacfb3a06847b47bb6b47e8556db43f3a43
DIFF: https://github.com/llvm/llvm-project/commit/8709bcacfb3a06847b47bb6b47e8556db43f3a43.diff

LOG: clang: Add __builtin_elementwise_fma

I didn't understand why the other builtins have promotion logic,
or how it would apply for a ternary operation. Implicit conversions
are evil to begin with, and even more so when the purpose is to get
an exact IR intrinsic. This checks all the arguments have the same type.

Added: 
    

Modified: 
    clang/docs/LanguageExtensions.rst
    clang/include/clang/Basic/Builtins.def
    clang/include/clang/Sema/Sema.h
    clang/lib/CodeGen/CGBuiltin.cpp
    clang/lib/Sema/SemaChecking.cpp
    clang/test/CodeGen/builtins-elementwise-math.c
    clang/test/Sema/builtins-elementwise-math.c

Removed: 
    


################################################################################
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index 2595b5de8f265..c0ea8afad6cb2 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -631,6 +631,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in
 =========================================== ================================================================ =========================================
  T __builtin_elementwise_abs(T x)            return the absolute value of a number x; the absolute value of   signed integer and floating point types
                                              the most negative integer remains the most negative integer
+ T __builtin_elementwise_fma(T x, T y, T z)  fused multiply add, (x * y) +  z.                                              floating point types
  T __builtin_elementwise_ceil(T x)           return the smallest integral value greater than or equal to x    floating point types
  T __builtin_elementwise_sin(T x)            return the sine of x interpreted as an angle in radians          floating point types
  T __builtin_elementwise_cos(T x)            return the cosine of x interpreted as an angle in radians        floating point types

diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
index 41288410786b0..6db599a3de116 100644
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -671,6 +671,7 @@ BUILTIN(__builtin_elementwise_sin, "v.", "nct")
 BUILTIN(__builtin_elementwise_trunc, "v.", "nct")
 BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct")
 BUILTIN(__builtin_elementwise_copysign, "v.", "nct")
+BUILTIN(__builtin_elementwise_fma, "v.", "nct")
 BUILTIN(__builtin_elementwise_add_sat, "v.", "nct")
 BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct")
 BUILTIN(__builtin_reduce_max, "v.", "nct")

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 41691eab4972b..0c6a3887e4151 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -13531,6 +13531,7 @@ class Sema final {
   bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
 
   bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
+  bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
   bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 52ec6e092c449..1535b14c7fb40 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3118,6 +3118,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc"));
   case Builtin::BI__builtin_elementwise_copysign:
     return RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign));
+  case Builtin::BI__builtin_elementwise_fma:
+    return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma));
   case Builtin::BI__builtin_elementwise_add_sat:
   case Builtin::BI__builtin_elementwise_sub_sat: {
     Value *Op0 = EmitScalarExpr(E->getArg(0));

diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index eded6061c77eb..485351f157fde 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2626,20 +2626,16 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
       return ExprError();
 
     QualType ArgTy = TheCall->getArg(0)->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-    if (!EltTy->isFloatingType()) {
-      Diag(TheCall->getArg(0)->getBeginLoc(),
-           diag::err_builtin_invalid_arg_type)
-          << 1 << /* float ty*/ 5 << ArgTy;
-
+    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
+                                      ArgTy, 1))
+      return ExprError();
+    break;
+  }
+  case Builtin::BI__builtin_elementwise_fma: {
+    if (SemaBuiltinElementwiseTernaryMath(TheCall))
       return ExprError();
-    }
     break;
   }
-
   // These builtins restrict the element type to integer
   // types only.
   case Builtin::BI__builtin_elementwise_add_sat:
@@ -17877,6 +17873,40 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
   return false;
 }
 
+bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) {
+  if (checkArgCount(*this, TheCall, 3))
+    return true;
+
+  Expr *Args[3];
+  for (int I = 0; I < 3; ++I) {
+    ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I));
+    if (Converted.isInvalid())
+      return true;
+    Args[I] = Converted.get();
+  }
+
+  int ArgOrdinal = 1;
+  for (Expr *Arg : Args) {
+    if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                      ArgOrdinal++))
+      return true;
+  }
+
+  for (int I = 1; I < 3; ++I) {
+    if (Args[0]->getType().getCanonicalType() !=
+        Args[I]->getType().getCanonicalType()) {
+      return Diag(Args[0]->getBeginLoc(),
+                  diag::err_typecheck_call_different_arg_types)
+             << Args[0]->getType() << Args[I]->getType();
+    }
+
+    TheCall->setArg(I, Args[I]);
+  }
+
+  TheCall->setType(Args[0]->getType());
+  return false;
+}
+
 bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) {
   if (checkArgCount(*this, TheCall, 1))
     return true;

diff --git a/clang/test/CodeGen/builtins-elementwise-math.c b/clang/test/CodeGen/builtins-elementwise-math.c
index 1571d2bb7f650..1b48a12b92056 100644
--- a/clang/test/CodeGen/builtins-elementwise-math.c
+++ b/clang/test/CodeGen/builtins-elementwise-math.c
@@ -1,5 +1,9 @@
 // RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
 
+typedef _Float16 half;
+
+typedef half half2 __attribute__((ext_vector_type(2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float4 __attribute__((ext_vector_type(4)));
 typedef short int si8 __attribute__((ext_vector_type(8)));
 typedef unsigned int u4 __attribute__((ext_vector_type(4)));
@@ -525,3 +529,77 @@ void test_builtin_elementwise_copysign(float f1, float f2, double d1, double d2,
   // CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> [[V2F64]])
   v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64);
 }
+
+void test_builtin_elementwise_fma(float f32, double f64,
+                                  float2 v2f32, float4 v4f32,
+                                  double2 v2f64, double3 v3f64,
+                                  const float4 c_v4f32,
+                                  half f16, half2 v2f16) {
+  // CHECK-LABEL: define void @test_builtin_elementwise_fma(
+  // CHECK:      [[F32_0:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr
+  // CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]])
+  float f2 = __builtin_elementwise_fma(f32, f32, f32);
+
+  // CHECK:      [[F64_0:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
+  double d2 = __builtin_elementwise_fma(f64, f64, f64);
+
+  // CHECK:      [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr
+  // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
+  float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32);
+
+
+  // FIXME: Are we really still doing the 3 vector load workaround
+  // CHECK:      [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector
+  // CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector
+  // CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr
+  // CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector
+    // CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]])
+  v3f64 = __builtin_elementwise_fma(v3f64, v3f64, v3f64);
+
+  // CHECK:      [[F64_0:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr
+  // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]])
+  v2f64 = __builtin_elementwise_fma(f64, f64, f64);
+
+  // CHECK:      [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr
+  // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]])
+  v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32);
+
+  // CHECK:      [[F16_0:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]])
+  half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16);
+
+  // CHECK:      [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
+  half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16);
+
+  // CHECK:      [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr
+  // CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement
+  // CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]])
+  half2 tmp1_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)f16);
+
+  // CHECK:      [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr
+  // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> <half 0xH4400, half 0xH4400>)
+  half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0);
+
+}

diff --git a/clang/test/Sema/builtins-elementwise-math.c b/clang/test/Sema/builtins-elementwise-math.c
index cb8b79739b6d2..c803fcea0d505 100644
--- a/clang/test/Sema/builtins-elementwise-math.c
+++ b/clang/test/Sema/builtins-elementwise-math.c
@@ -4,6 +4,8 @@ typedef double double2 __attribute__((ext_vector_type(2)));
 typedef double double4 __attribute__((ext_vector_type(4)));
 typedef float float2 __attribute__((ext_vector_type(2)));
 typedef float float4 __attribute__((ext_vector_type(4)));
+
+typedef int int2 __attribute__((ext_vector_type(2)));
 typedef int int3 __attribute__((ext_vector_type(3)));
 typedef unsigned unsigned3 __attribute__((ext_vector_type(3)));
 typedef unsigned unsigned4 __attribute__((ext_vector_type(4)));
@@ -572,3 +574,84 @@ void test_builtin_elementwise_copysign(int i, short s, double d, float f, float4
   float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32);
   // expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}}
 }
+
+void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16,
+                                  double f64, double2 v2f64, double2 v3f64,
+                                  float f32, float2 v2f32, float v3f32, float4 v4f32,
+                                  const float4 c_v4f32,
+                                  int3 v3i32, int *ptr) {
+
+  f32 = __builtin_elementwise_fma();
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
+
+  f32 = __builtin_elementwise_fma(f32);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 1}}
+
+  f32 = __builtin_elementwise_fma(f32, f32);
+  // expected-error@-1 {{too few arguments to function call, expected 3, have 2}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f32, f32);
+  // expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
+
+  f32 = __builtin_elementwise_fma(f64, f32, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f32 = __builtin_elementwise_fma(f32, f64, f32);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('float' vs 'double')}}
+
+  f64 = __builtin_elementwise_fma(f64, f32, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f64 = __builtin_elementwise_fma(f64, f64, f32);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  f64 = __builtin_elementwise_fma(f64, f32, f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}}
+
+  v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64);
+  // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}}
+
+  v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}}
+
+  v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64);
+  // expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 'double' values)}}
+
+  i32 = __builtin_elementwise_fma(i32, i32, i32);
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int')}}
+
+  v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32);
+  // expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, i32);
+  // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}}
+
+  f32 = __builtin_elementwise_fma(f32, i32, f32);
+  // expected-error@-1 {{2nd argument must be a floating point type (was 'int')}}
+
+  f32 = __builtin_elementwise_fma(f32, f32, i32);
+  // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}}
+
+
+  _Complex float c1, c2, c3;
+  c1 = __builtin_elementwise_fma(c1, f32, f32);
+  // expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}}
+
+  c2 = __builtin_elementwise_fma(f32, c2, f32);
+  // expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}}
+
+  c3 = __builtin_elementwise_fma(f32, f32, c3);
+  // expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}}
+}


        


More information about the cfe-commits mailing list