[clang] [clang] fix half && bfloat16 convert node expr codegen (PR #89051)

via cfe-commits cfe-commits at lists.llvm.org
Thu Apr 18 06:20:17 PDT 2024


https://github.com/JinjinLi868 updated https://github.com/llvm/llvm-project/pull/89051

>From 9e6c2a16172c66b7a9eec7957d95b4239f178368 Mon Sep 17 00:00:00 2001
From: Jinjin Li <lijinjin.868 at bytedance.com>
Date: Wed, 17 Apr 2024 16:44:50 +0800
Subject: [PATCH] [clang] Fix codegen for conversions between half and bfloat16

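Conversions between fp16 (half) and bf16 (bfloat) were emitted as a single
fptrunc or fpext, which is invalid LLVM IR because both types are 16 bits
wide. Emit them instead as an fpext to float followed by an fptrunc to the
destination type. Widening either type to float is exact, so the value is
rounded only once.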
---
 clang/lib/CodeGen/CGExprScalar.cpp            |  15 ++-
 .../test/CodeGen/X86/bfloat16-convert-half.c  | 113 ++++++++++++++++++
 2 files changed, 125 insertions(+), 3 deletions(-)
 create mode 100644 clang/test/CodeGen/X86/bfloat16-convert-half.c

diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 1f18e0d5ba409a..d4c60a2a7ffcaf 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -1431,9 +1431,17 @@ Value *ScalarExprEmitter::EmitScalarCast(Value *Src, QualType SrcType,
     return Builder.CreateFPToUI(Src, DstTy, "conv");
   }
 
-  if (DstElementTy->getTypeID() < SrcElementTy->getTypeID())
+  // half and bfloat are both 16 bits wide, so neither fptrunc nor fpext is a
+  // valid cast between them. Extend to float (exact for both), then round
+  // once to the destination. Src may be a vector, so keep its element count.
+  if (DstElementTy->is16bitFPTy() && SrcElementTy->is16bitFPTy()) {
+    Value *FloatVal = Builder.CreateFPExt(
+        Src, Src->getType()->getWithNewType(Builder.getFloatTy()), "conv");
+    return Builder.CreateFPTrunc(FloatVal, DstTy, "conv");
+  }
+  if (DstElementTy->getTypeID() < SrcElementTy->getTypeID())
     return Builder.CreateFPTrunc(Src, DstTy, "conv");
   return Builder.CreateFPExt(Src, DstTy, "conv");
 }
 
 /// Emit a conversion from the specified type to the specified destination type,
@@ -1906,7 +1914,13 @@ Value *ScalarExprEmitter::VisitConvertVectorExpr(ConvertVectorExpr *E) {
   } else {
     assert(SrcEltTy->isFloatingPointTy() && DstEltTy->isFloatingPointTy() &&
            "Unknown real conversion");
-    if (DstEltTy->getTypeID() < SrcEltTy->getTypeID())
+    if (DstEltTy->is16bitFPTy() && SrcEltTy->is16bitFPTy()) {
+      // half <-> bfloat: same width, so fpext/fptrunc cannot go directly.
+      // Widen to a float vector with the same element count, then round.
+      Value *FloatVal = Builder.CreateFPExt(
+          Src, Src->getType()->getWithNewType(Builder.getFloatTy()), "conv");
+      Res = Builder.CreateFPTrunc(FloatVal, DstTy, "conv");
+    } else if (DstEltTy->getTypeID() < SrcEltTy->getTypeID())
       Res = Builder.CreateFPTrunc(Src, DstTy, "conv");
     else
       Res = Builder.CreateFPExt(Src, DstTy, "conv");
diff --git a/clang/test/CodeGen/X86/bfloat16-convert-half.c b/clang/test/CodeGen/X86/bfloat16-convert-half.c
new file mode 100644
index 00000000000000..ad12bd3f654175
--- /dev/null
+++ b/clang/test/CodeGen/X86/bfloat16-convert-half.c
@@ -0,0 +1,155 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 4
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -target-feature +fullbf16 -S -emit-llvm %s -o - | FileCheck %s
+// CHECK-LABEL: define dso_local half @test_convert_from_bf16_to_fp16(
+// CHECK-SAME: bfloat noundef [[A:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[A_ADDR:%.*]] = alloca bfloat, align 2
+// CHECK-NEXT:    store bfloat [[A]], ptr [[A_ADDR]], align 2
+// CHECK-NEXT:    [[TMP0:%.*]] = load bfloat, ptr [[A_ADDR]], align 2
+// CHECK-NEXT:    [[CONV:%.*]] = fpext bfloat [[TMP0]] to float
+// CHECK-NEXT:    [[CONV1:%.*]] = fptrunc float [[CONV]] to half
+// CHECK-NEXT:    ret half [[CONV1]]
+//
+_Float16 test_convert_from_bf16_to_fp16(__bf16 a) {
+    return (_Float16)a;
+}
+
+// CHECK-LABEL: define dso_local bfloat @test_convert_from_fp16_to_bf16(
+// CHECK-SAME: half noundef [[A:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[A_ADDR:%.*]] = alloca half, align 2
+// CHECK-NEXT:    store half [[A]], ptr [[A_ADDR]], align 2
+// CHECK-NEXT:    [[TMP0:%.*]] = load half, ptr [[A_ADDR]], align 2
+// CHECK-NEXT:    [[CONV:%.*]] = fpext half [[TMP0]] to float
+// CHECK-NEXT:    [[CONV1:%.*]] = fptrunc float [[CONV]] to bfloat
+// CHECK-NEXT:    ret bfloat [[CONV1]]
+//
+__bf16 test_convert_from_fp16_to_bf16(_Float16 a) {
+    return (__bf16)a;
+}
+
+typedef _Float16 half2 __attribute__((ext_vector_type(2)));
+typedef _Float16 half4 __attribute__((ext_vector_type(4)));
+
+typedef __bf16 bfloat2 __attribute__((ext_vector_type(2)));
+typedef __bf16 bfloat4 __attribute__((ext_vector_type(4)));
+
+// CHECK-LABEL: define dso_local i32 @test_cast_from_fp162_to_bf162(
+// CHECK-SAME: i32 noundef [[IN_COERCE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <2 x bfloat>, align 4
+// CHECK-NEXT:    [[IN:%.*]] = alloca <2 x half>, align 4
+// CHECK-NEXT:    [[IN_ADDR:%.*]] = alloca <2 x half>, align 4
+// CHECK-NEXT:    store i32 [[IN_COERCE]], ptr [[IN]], align 4
+// CHECK-NEXT:    [[IN1:%.*]] = load <2 x half>, ptr [[IN]], align 4
+// CHECK-NEXT:    store <2 x half> [[IN1]], ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <2 x half>, ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x half> [[TMP0]] to <2 x bfloat>
+// CHECK-NEXT:    store <2 x bfloat> [[TMP1]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i32 [[TMP2]]
+//
+bfloat2 test_cast_from_fp162_to_bf162(half2 in) {
+  return (bfloat2)in;
+}
+
+
+// CHECK-LABEL: define dso_local double @test_cast_from_fp164_to_bf164(
+// CHECK-SAME: double noundef [[IN_COERCE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <4 x bfloat>, align 8
+// CHECK-NEXT:    [[IN:%.*]] = alloca <4 x half>, align 8
+// CHECK-NEXT:    [[IN_ADDR:%.*]] = alloca <4 x half>, align 8
+// CHECK-NEXT:    store double [[IN_COERCE]], ptr [[IN]], align 8
+// CHECK-NEXT:    [[IN1:%.*]] = load <4 x half>, ptr [[IN]], align 8
+// CHECK-NEXT:    store <4 x half> [[IN1]], ptr [[IN_ADDR]], align 8
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x half>, ptr [[IN_ADDR]], align 8
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x half> [[TMP0]] to <4 x bfloat>
+// CHECK-NEXT:    store <4 x bfloat> [[TMP1]], ptr [[RETVAL]], align 8
+// CHECK-NEXT:    [[TMP2:%.*]] = load double, ptr [[RETVAL]], align 8
+// CHECK-NEXT:    ret double [[TMP2]]
+//
+bfloat4 test_cast_from_fp164_to_bf164(half4 in) {
+  return (bfloat4)in;
+}
+
+// CHECK-LABEL: define dso_local i32 @test_cast_from_bf162_to_fp162(
+// CHECK-SAME: i32 noundef [[IN_COERCE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <2 x half>, align 4
+// CHECK-NEXT:    [[IN:%.*]] = alloca <2 x bfloat>, align 4
+// CHECK-NEXT:    [[IN_ADDR:%.*]] = alloca <2 x bfloat>, align 4
+// CHECK-NEXT:    store i32 [[IN_COERCE]], ptr [[IN]], align 4
+// CHECK-NEXT:    [[IN1:%.*]] = load <2 x bfloat>, ptr [[IN]], align 4
+// CHECK-NEXT:    store <2 x bfloat> [[IN1]], ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <2 x bfloat>, ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x bfloat> [[TMP0]] to <2 x half>
+// CHECK-NEXT:    store <2 x half> [[TMP1]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i32 [[TMP2]]
+//
+half2 test_cast_from_bf162_to_fp162(bfloat2 in) {
+  return (half2)in;
+}
+
+
+// CHECK-LABEL: define dso_local double @test_cast_from_bf164_to_fp164(
+// CHECK-SAME: double noundef [[IN_COERCE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <4 x half>, align 8
+// CHECK-NEXT:    [[IN:%.*]] = alloca <4 x bfloat>, align 8
+// CHECK-NEXT:    [[IN_ADDR:%.*]] = alloca <4 x bfloat>, align 8
+// CHECK-NEXT:    store double [[IN_COERCE]], ptr [[IN]], align 8
+// CHECK-NEXT:    [[IN1:%.*]] = load <4 x bfloat>, ptr [[IN]], align 8
+// CHECK-NEXT:    store <4 x bfloat> [[IN1]], ptr [[IN_ADDR]], align 8
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x bfloat>, ptr [[IN_ADDR]], align 8
+// CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x bfloat> [[TMP0]] to <4 x half>
+// CHECK-NEXT:    store <4 x half> [[TMP1]], ptr [[RETVAL]], align 8
+// CHECK-NEXT:    [[TMP2:%.*]] = load double, ptr [[RETVAL]], align 8
+// CHECK-NEXT:    ret double [[TMP2]]
+//
+half4 test_cast_from_bf164_to_fp164(bfloat4 in) {
+  return (half4)in;
+}
+
+// Unlike the C-style vector casts above, which reinterpret bits,
+// __builtin_convertvector converts each element's value via fpext/fptrunc.
+// CHECK-LABEL: define dso_local i32 @test_convertvector_from_fp162_to_bf162(
+// CHECK-SAME: i32 noundef [[IN_COERCE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <2 x bfloat>, align 4
+// CHECK-NEXT:    [[IN:%.*]] = alloca <2 x half>, align 4
+// CHECK-NEXT:    [[IN_ADDR:%.*]] = alloca <2 x half>, align 4
+// CHECK-NEXT:    store i32 [[IN_COERCE]], ptr [[IN]], align 4
+// CHECK-NEXT:    [[IN1:%.*]] = load <2 x half>, ptr [[IN]], align 4
+// CHECK-NEXT:    store <2 x half> [[IN1]], ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <2 x half>, ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[CONV:%.*]] = fpext <2 x half> [[TMP0]] to <2 x float>
+// CHECK-NEXT:    [[CONV1:%.*]] = fptrunc <2 x float> [[CONV]] to <2 x bfloat>
+// CHECK-NEXT:    store <2 x bfloat> [[CONV1]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i32 [[TMP1]]
+//
+bfloat2 test_convertvector_from_fp162_to_bf162(half2 in) {
+  return __builtin_convertvector(in, bfloat2);
+}
+
+// CHECK-LABEL: define dso_local i32 @test_convertvector_from_bf162_to_fp162(
+// CHECK-SAME: i32 noundef [[IN_COERCE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <2 x half>, align 4
+// CHECK-NEXT:    [[IN:%.*]] = alloca <2 x bfloat>, align 4
+// CHECK-NEXT:    [[IN_ADDR:%.*]] = alloca <2 x bfloat>, align 4
+// CHECK-NEXT:    store i32 [[IN_COERCE]], ptr [[IN]], align 4
+// CHECK-NEXT:    [[IN1:%.*]] = load <2 x bfloat>, ptr [[IN]], align 4
+// CHECK-NEXT:    store <2 x bfloat> [[IN1]], ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <2 x bfloat>, ptr [[IN_ADDR]], align 4
+// CHECK-NEXT:    [[CONV:%.*]] = fpext <2 x bfloat> [[TMP0]] to <2 x float>
+// CHECK-NEXT:    [[CONV1:%.*]] = fptrunc <2 x float> [[CONV]] to <2 x half>
+// CHECK-NEXT:    store <2 x half> [[CONV1]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i32 [[TMP1]]
+//
+half2 test_convertvector_from_bf162_to_fp162(bfloat2 in) {
+  return __builtin_convertvector(in, half2);
+}



More information about the cfe-commits mailing list