[clang] e9e1d47 - [X86] Refactor GetSSETypeAtOffset to fix pr51813

via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 16 20:13:20 PDT 2021


Author: Wang, Pengfei
Date: 2021-09-17T10:51:59+08:00
New Revision: e9e1d4751b54126743ed72b1a9178ee51200acf6

URL: https://github.com/llvm/llvm-project/commit/e9e1d4751b54126743ed72b1a9178ee51200acf6
DIFF: https://github.com/llvm/llvm-project/commit/e9e1d4751b54126743ed72b1a9178ee51200acf6.diff

LOG: [X86] Refactor GetSSETypeAtOffset to fix pr51813

D105263 added support for the _Float16 type. It introduced a bug (pr51813) that generates a <4 x half> type instead of the default double when passing a blank structure in SSE registers.

Although I doubt it may expose a bug somewhere other than D105263, it's good to avoid returning a half type when there is no half type in the arguments.

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D109607

Added: 
    

Modified: 
    clang/lib/CodeGen/TargetInfo.cpp
    clang/test/CodeGen/X86/avx512fp16-abi.c

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/TargetInfo.cpp b/clang/lib/CodeGen/TargetInfo.cpp
index 67d0c2b5c850..58d43bd36565 100644
--- a/clang/lib/CodeGen/TargetInfo.cpp
+++ b/clang/lib/CodeGen/TargetInfo.cpp
@@ -3407,52 +3407,18 @@ static bool BitsContainNoUserData(QualType Ty, unsigned StartBit,
   return false;
 }
 
-/// ContainsFloatAtOffset - Return true if the specified LLVM IR type has a
-/// float member at the specified offset.  For example, {int,{float}} has a
-/// float at offset 4.  It is conservatively correct for this routine to return
-/// false.
-static bool ContainsFloatAtOffset(llvm::Type *IRType, unsigned IROffset,
-                                  const llvm::DataLayout &TD) {
-  // Base case if we find a float.
-  if (IROffset == 0 && IRType->isFloatTy())
-    return true;
-
-  // If this is a struct, recurse into the field at the specified offset.
-  if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
-    const llvm::StructLayout *SL = TD.getStructLayout(STy);
-    unsigned Elt = SL->getElementContainingOffset(IROffset);
-    IROffset -= SL->getElementOffset(Elt);
-    return ContainsFloatAtOffset(STy->getElementType(Elt), IROffset, TD);
-  }
-
-  // If this is an array, recurse into the field at the specified offset.
-  if (llvm::ArrayType *ATy = dyn_cast<llvm::ArrayType>(IRType)) {
-    llvm::Type *EltTy = ATy->getElementType();
-    unsigned EltSize = TD.getTypeAllocSize(EltTy);
-    IROffset -= IROffset/EltSize*EltSize;
-    return ContainsFloatAtOffset(EltTy, IROffset, TD);
-  }
-
-  return false;
-}
-
-/// ContainsHalfAtOffset - Return true if the specified LLVM IR type has a
-/// half member at the specified offset.  For example, {int,{half}} has a
-/// half at offset 4.  It is conservatively correct for this routine to return
-/// false.
-/// FIXME: Merge with ContainsFloatAtOffset
-static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
-                                 const llvm::DataLayout &TD) {
-  // Base case if we find a float.
-  if (IROffset == 0 && IRType->isHalfTy())
-    return true;
+/// getFPTypeAtOffset - Return a floating point type at the specified offset.
+static llvm::Type *getFPTypeAtOffset(llvm::Type *IRType, unsigned IROffset,
+                                     const llvm::DataLayout &TD) {
+  if (IROffset == 0 && IRType->isFloatingPointTy())
+    return IRType;
 
   // If this is a struct, recurse into the field at the specified offset.
   if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
     const llvm::StructLayout *SL = TD.getStructLayout(STy);
     unsigned Elt = SL->getElementContainingOffset(IROffset);
     IROffset -= SL->getElementOffset(Elt);
-    return ContainsHalfAtOffset(STy->getElementType(Elt), IROffset, TD);
+    return getFPTypeAtOffset(STy->getElementType(Elt), IROffset, TD);
   }
 
   // If this is an array, recurse into the field at the specified offset.
@@ -3460,10 +3426,10 @@ static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
     llvm::Type *EltTy = ATy->getElementType();
     unsigned EltSize = TD.getTypeAllocSize(EltTy);
     IROffset -= IROffset / EltSize * EltSize;
-    return ContainsHalfAtOffset(EltTy, IROffset, TD);
+    return getFPTypeAtOffset(EltTy, IROffset, TD);
   }
 
-  return false;
+  return nullptr;
 }
 
 /// GetSSETypeAtOffset - Return a type that will be passed by the backend in the
@@ -3471,39 +3437,37 @@ static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
 llvm::Type *X86_64ABIInfo::
 GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
                    QualType SourceTy, unsigned SourceOffset) const {
-  // If the high 32 bits are not used, we have three choices. Single half,
-  // single float or two halfs.
-  if (BitsContainNoUserData(SourceTy, SourceOffset * 8 + 32,
-                            SourceOffset * 8 + 64, getContext())) {
-    if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()))
-      return llvm::Type::getFloatTy(getVMContext());
-    if (ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()))
-      return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()),
-                                        2);
-
-    return llvm::Type::getHalfTy(getVMContext());
-  }
-
-  // We want to pass as <2 x float> if the LLVM IR type contains a float at
-  // offset+0 and offset+4. Walk the LLVM IR type to find out if this is the
-  // case.
-  if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) &&
-      ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
-    return llvm::FixedVectorType::get(llvm::Type::getFloatTy(getVMContext()),
-                                      2);
-
-  // We want to pass as <4 x half> if the LLVM IR type contains a half at
-  // offset+0, +2, +4. Walk the LLVM IR type to find out if this is the case.
-  if (ContainsHalfAtOffset(IRType, IROffset, getDataLayout()) &&
-      ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()) &&
-      ContainsHalfAtOffset(IRType, IROffset + 4, getDataLayout()))
-    return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
-
-  // We want to pass as <4 x half> if the LLVM IR type contains a mix of float
-  // and half.
-  // FIXME: Do we have a better representation for the mixed type?
-  if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) ||
-      ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
+  const llvm::DataLayout &TD = getDataLayout();
+  llvm::Type *T0 = getFPTypeAtOffset(IRType, IROffset, TD);
+  if (!T0 || T0->isDoubleTy())
+    return llvm::Type::getDoubleTy(getVMContext());
+
+  // Get the adjacent FP type.
+  llvm::Type *T1 =
+      getFPTypeAtOffset(IRType, IROffset + TD.getTypeAllocSize(T0), TD);
+  if (T1 == nullptr) {
+    // Check if IRType is a half + float. float type will be in IROffset+4 due
+    // to its alignment.
+    if (T0->isHalfTy())
+      T1 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
+    // If we can't get a second FP type, return a simple half or float.
+    // avx512fp16-abi.c:pr51813_2 shows it works to return float for
+    // {float, i8} too.
+    if (T1 == nullptr)
+      return T0;
+  }
+
+  if (T0->isFloatTy() && T1->isFloatTy())
+    return llvm::FixedVectorType::get(T0, 2);
+
+  if (T0->isHalfTy() && T1->isHalfTy()) {
+    llvm::Type *T2 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
+    if (T2 == nullptr)
+      return llvm::FixedVectorType::get(T0, 2);
+    return llvm::FixedVectorType::get(T0, 4);
+  }
+
+  if (T0->isHalfTy() || T1->isHalfTy())
     return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
 
   return llvm::Type::getDoubleTy(getVMContext());

diff  --git a/clang/test/CodeGen/X86/avx512fp16-abi.c b/clang/test/CodeGen/X86/avx512fp16-abi.c
index dd5bf6a89c1d..5018567d5bb4 100644
--- a/clang/test/CodeGen/X86/avx512fp16-abi.c
+++ b/clang/test/CodeGen/X86/avx512fp16-abi.c
@@ -1,11 +1,12 @@
-// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK
+// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-C
+// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm  -target-feature +avx512fp16 -x c++ -std=c++11 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-CPP
 
 struct half1 {
   _Float16 a;
 };
 
 struct half1 h1(_Float16 a) {
-  // CHECK: define{{.*}}half @h1
+  // CHECK: define{{.*}}half @
   struct half1 x;
   x.a = a;
   return x;
@@ -17,7 +18,7 @@ struct half2 {
 };
 
 struct half2 h2(_Float16 a, _Float16 b) {
-  // CHECK: define{{.*}}<2 x half> @h2
+  // CHECK: define{{.*}}<2 x half> @
   struct half2 x;
   x.a = a;
   x.b = b;
@@ -31,7 +32,7 @@ struct half3 {
 };
 
 struct half3 h3(_Float16 a, _Float16 b, _Float16 c) {
-  // CHECK: define{{.*}}<4 x half> @h3
+  // CHECK: define{{.*}}<4 x half> @
   struct half3 x;
   x.a = a;
   x.b = b;
@@ -47,7 +48,7 @@ struct half4 {
 };
 
 struct half4 h4(_Float16 a, _Float16 b, _Float16 c, _Float16 d) {
-  // CHECK: define{{.*}}<4 x half> @h4
+  // CHECK: define{{.*}}<4 x half> @
   struct half4 x;
   x.a = a;
   x.b = b;
@@ -62,7 +63,7 @@ struct floathalf {
 };
 
 struct floathalf fh(float a, _Float16 b) {
-  // CHECK: define{{.*}}<4 x half> @fh
+  // CHECK: define{{.*}}<4 x half> @
   struct floathalf x;
   x.a = a;
   x.b = b;
@@ -76,7 +77,7 @@ struct floathalf2 {
 };
 
 struct floathalf2 fh2(float a, _Float16 b, _Float16 c) {
-  // CHECK: define{{.*}}<4 x half> @fh2
+  // CHECK: define{{.*}}<4 x half> @
   struct floathalf2 x;
   x.a = a;
   x.b = b;
@@ -90,7 +91,7 @@ struct halffloat {
 };
 
 struct halffloat hf(_Float16 a, float b) {
-  // CHECK: define{{.*}}<4 x half> @hf
+  // CHECK: define{{.*}}<4 x half> @
   struct halffloat x;
   x.a = a;
   x.b = b;
@@ -104,7 +105,7 @@ struct half2float {
 };
 
 struct half2float h2f(_Float16 a, _Float16 b, float c) {
-  // CHECK: define{{.*}}<4 x half> @h2f
+  // CHECK: define{{.*}}<4 x half> @
   struct half2float x;
   x.a = a;
   x.b = b;
@@ -120,7 +121,7 @@ struct floathalf3 {
 };
 
 struct floathalf3 fh3(float a, _Float16 b, _Float16 c, _Float16 d) {
-  // CHECK: define{{.*}}{ <4 x half>, half } @fh3
+  // CHECK: define{{.*}}{ <4 x half>, half } @
   struct floathalf3 x;
   x.a = a;
   x.b = b;
@@ -138,7 +139,7 @@ struct half5 {
 };
 
 struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
-  // CHECK: define{{.*}}{ <4 x half>, half } @h5
+  // CHECK: define{{.*}}{ <4 x half>, half } @
   struct half5 x;
   x.a = a;
   x.b = b;
@@ -147,3 +148,52 @@ struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
   x.e = e;
   return x;
 }
+
+struct float2 {
+  struct {} s;
+  float a;
+  float b;
+};
+
+float pr51813(struct float2 s) {
+  // CHECK-C: define{{.*}} @pr51813(<2 x float>
+  // CHECK-CPP: define{{.*}} @_Z7pr518136float2(double {{.*}}, float
+  return s.a;
+}
+
+struct float3 {
+  float a;
+  struct {} s;
+  float b;
+};
+
+float pr51813_2(struct float3 s) {
+  // CHECK-C: define{{.*}} @pr51813_2(<2 x float>
+  // CHECK-CPP: define{{.*}} @_Z9pr51813_26float3(double {{.*}}, float
+  return s.a;
+}
+
+struct shalf2 {
+  struct {} s;
+  _Float16 a;
+  _Float16 b;
+};
+
+_Float16 sf2(struct shalf2 s) {
+  // CHECK-C: define{{.*}} @sf2(<2 x half>
+  // CHECK-CPP: define{{.*}} @_Z3sf26shalf2(double {{.*}}
+  return s.a;
+};
+
+struct halfs2 {
+  _Float16 a;
+  struct {} s1;
+  _Float16 b;
+  struct {} s2;
+};
+
+_Float16 fs2(struct shalf2 s) {
+  // CHECK-C: define{{.*}} @fs2(<2 x half>
+  // CHECK-CPP: define{{.*}} @_Z3fs26shalf2(double {{.*}}
+  return s.a;
+};


        


More information about the cfe-commits mailing list