[clang] [Clang] [Sema] Reject non-power-of-2 `_BitInt` matrix element types (PR #117487)

via cfe-commits cfe-commits at lists.llvm.org
Wed Dec 4 00:10:51 PST 2024


https://github.com/Sirraide updated https://github.com/llvm/llvm-project/pull/117487

>From a612c8f0a78dd1f29a4885f57efbdd4a9cca374e Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 24 Nov 2024 14:36:41 +0100
Subject: [PATCH 1/5] [Clang] [Sema] Reject non-power-of-2 `_BitInt` matrix
 element types

---
 clang/docs/ReleaseNotes.rst                   |  3 ++
 .../clang/Basic/DiagnosticSemaKinds.td        |  4 +-
 clang/lib/Sema/SemaType.cpp                   | 41 +++++++++++--------
 clang/test/SemaCXX/matrix-type.cpp            | 10 +++++
 4 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 8bd06fadfdc984..1c0e4043bbe276 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -371,6 +371,9 @@ Non-comprehensive list of changes in this release
 - ``__builtin_reduce_and`` function can now be used in constant expressions.
 - ``__builtin_reduce_or`` and ``__builtin_reduce_xor`` functions can now be used in constant expressions.
 
+- Clang now rejects ``_BitInt`` matrix element types if the bit width is less than ``CHAR_WIDTH`` or
+  not a power of two, matching preexisting behaviour for vector types.
+
 New Compiler Flags
 ------------------
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index eb05a6a77978af..f049e72a6b8694 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3233,8 +3233,8 @@ def err_attribute_too_few_arguments : Error<
   "%0 attribute takes at least %1 argument%s1">;
 def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
 def err_attribute_invalid_bitint_vector_type : Error<
-  "'_BitInt' vector element width must be %select{a power of 2|"
-  "at least as wide as 'CHAR_BIT'}0">;
+  "'_BitInt' %select{vector|matrix}0 element width must be %select{a power of 2|"
+  "at least as wide as 'CHAR_BIT'}1">;
 def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
 def err_attribute_bad_neon_vector_size : Error<
   "Neon vector size must be 64 or 128 bits">;
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index f32edc5ac06440..06b779f5ef3aa2 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2312,6 +2312,18 @@ QualType Sema::BuildArrayType(QualType T, ArraySizeModifier ASM,
   return T;
 }
 
+bool CheckBitIntElementType(Sema &S, SourceLocation AttrLoc,
+                            const BitIntType *BIT, bool ForMatrixType = false) {
+  // Only support _BitInt elements with byte-sized power of 2 NumBits.
+  unsigned NumBits = BIT->getNumBits();
+  if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
+    S.Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
+        << ForMatrixType << (NumBits < 8);
+    return true;
+  }
+  return false;
+}
+
 QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
                                SourceLocation AttrLoc) {
   // The base type must be integer (not Boolean or enumeration) or float, and
@@ -2324,15 +2336,10 @@ QualType Sema::BuildVectorType(QualType CurType, Expr *SizeExpr,
     Diag(AttrLoc, diag::err_attribute_invalid_vector_type) << CurType;
     return QualType();
   }
-  // Only support _BitInt elements with byte-sized power of 2 NumBits.
-  if (const auto *BIT = CurType->getAs<BitIntType>()) {
-    unsigned NumBits = BIT->getNumBits();
-    if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
-      Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
-          << (NumBits < 8);
-      return QualType();
-    }
-  }
+
+  if (const auto *BIT = CurType->getAs<BitIntType>();
+      BIT && CheckBitIntElementType(*this, AttrLoc, BIT))
+    return QualType();
 
   if (SizeExpr->isTypeDependent() || SizeExpr->isValueDependent())
     return Context.getDependentVectorType(CurType, SizeExpr, AttrLoc,
@@ -2402,15 +2409,9 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
     return QualType();
   }
 
-  // Only support _BitInt elements with byte-sized power of 2 NumBits.
-  if (T->isBitIntType()) {
-    unsigned NumBits = T->castAs<BitIntType>()->getNumBits();
-    if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
-      Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
-          << (NumBits < 8);
-      return QualType();
-    }
-  }
+  if (const auto *BIT = T->getAs<BitIntType>();
+      BIT && CheckBitIntElementType(*this, AttrLoc, BIT))
+    return QualType();
 
   if (!ArraySize->isTypeDependent() && !ArraySize->isValueDependent()) {
     std::optional<llvm::APSInt> vecSize =
@@ -2455,6 +2456,10 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
     return QualType();
   }
 
+  if (const auto *BIT = ElementTy->getAs<BitIntType>();
+      BIT && CheckBitIntElementType(*this, AttrLoc, BIT, true))
+    return QualType();
+
   if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
       NumRows->isValueDependent() || NumCols->isValueDependent())
     return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
diff --git a/clang/test/SemaCXX/matrix-type.cpp b/clang/test/SemaCXX/matrix-type.cpp
index af31e267fdcae8..d347332ee2d60f 100644
--- a/clang/test/SemaCXX/matrix-type.cpp
+++ b/clang/test/SemaCXX/matrix-type.cpp
@@ -29,3 +29,13 @@ void matrix_unsupported_element_type() {
   using matrix3_t = bool __attribute__((matrix_type(1, 1)));     // expected-error{{invalid matrix element type 'bool'}}
   using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}}
 }
+
+void matrix_unsupported_bit_int() {
+  using m1 = _BitInt(2) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
+  using m2 = _BitInt(7) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be at least as wide as 'CHAR_BIT'}}
+  using m3 = _BitInt(9) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
+  using m4 = _BitInt(12) __attribute__((matrix_type(4, 4))); // expected-error{{'_BitInt' matrix element width must be a power of 2}}
+  using m5 = _BitInt(8) __attribute__((matrix_type(4, 4)));
+  using m6 = _BitInt(64) __attribute__((matrix_type(4, 4)));
+  using m7 = _BitInt(256) __attribute__((matrix_type(4, 4)));
+}

>From 00fc390b339863f6e6655ccbdde99fda4932051f Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Sun, 24 Nov 2024 15:04:14 +0100
Subject: [PATCH 2/5] Remove -pedantic

---
 clang/test/SemaCXX/matrix-type.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/test/SemaCXX/matrix-type.cpp b/clang/test/SemaCXX/matrix-type.cpp
index d347332ee2d60f..bb7a8421ca9e37 100644
--- a/clang/test/SemaCXX/matrix-type.cpp
+++ b/clang/test/SemaCXX/matrix-type.cpp
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+// RUN: %clang_cc1 -fsyntax-only -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
 
 using matrix_double_t = double __attribute__((matrix_type(6, 6)));
 using matrix_float_t = float __attribute__((matrix_type(6, 6)));

>From 4e4fdb98a7e1f0ac4f857838620b5af29e5b6195 Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Mon, 25 Nov 2024 12:32:24 +0100
Subject: [PATCH 3/5] Apply feedback from code review

---
 clang/lib/Sema/SemaType.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 06b779f5ef3aa2..03308c067a9c8f 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2312,15 +2312,14 @@ QualType Sema::BuildArrayType(QualType T, ArraySizeModifier ASM,
   return T;
 }
 
-bool CheckBitIntElementType(Sema &S, SourceLocation AttrLoc,
-                            const BitIntType *BIT, bool ForMatrixType = false) {
+static bool CheckBitIntElementType(Sema &S, SourceLocation AttrLoc,
+                                   const BitIntType *BIT,
+                                   bool ForMatrixType = false) {
   // Only support _BitInt elements with byte-sized power of 2 NumBits.
   unsigned NumBits = BIT->getNumBits();
-  if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8) {
-    S.Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
-        << ForMatrixType << (NumBits < 8);
-    return true;
-  }
+  if (!llvm::isPowerOf2_32(NumBits) || NumBits < 8)
+    return S.Diag(AttrLoc, diag::err_attribute_invalid_bitint_vector_type)
+           << ForMatrixType << (NumBits < 8);
   return false;
 }
 
@@ -2457,7 +2456,8 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
   }
 
   if (const auto *BIT = ElementTy->getAs<BitIntType>();
-      BIT && CheckBitIntElementType(*this, AttrLoc, BIT, true))
+      BIT &&
+      CheckBitIntElementType(*this, AttrLoc, BIT, /*ForMatrixType=*/true))
     return QualType();
 
   if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||

>From 7a54083ef7916ebfb05d4451d61e0d83d7400641 Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Wed, 4 Dec 2024 09:09:47 +0100
Subject: [PATCH 4/5] Add tests for _BitInt matrix and vector types

---
 .../test/CodeGenCXX/matrix-vector-bit-int.cpp | 98 +++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 clang/test/CodeGenCXX/matrix-vector-bit-int.cpp

diff --git a/clang/test/CodeGenCXX/matrix-vector-bit-int.cpp b/clang/test/CodeGenCXX/matrix-vector-bit-int.cpp
new file mode 100644
index 00000000000000..040615f4170855
--- /dev/null
+++ b/clang/test/CodeGenCXX/matrix-vector-bit-int.cpp
@@ -0,0 +1,98 @@
+// RUN: %clang_cc1 -fenable-matrix %s -emit-llvm -triple x86_64-unknown-linux -disable-llvm-passes -o - -std=c++11 | FileCheck %s
+
+using i8x3 = _BitInt(8) __attribute__((ext_vector_type(3)));
+using i8x3x3 = _BitInt(8) __attribute__((matrix_type(3, 3)));
+using i32x3 = _BitInt(32) __attribute__((ext_vector_type(3)));
+using i32x3x3 = _BitInt(32) __attribute__((matrix_type(3, 3)));
+using i512x3 = _BitInt(512) __attribute__((ext_vector_type(3)));
+using i512x3x3 = _BitInt(512) __attribute__((matrix_type(3, 3)));
+
+// CHECK-LABEL: define dso_local i32 @_Z2v1Dv3_DB8_(i32 %a.coerce)
+i8x3 v1(i8x3 a) {
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %retval = alloca <3 x i8>, align 4
+  // CHECK-NEXT:   %a = alloca <3 x i8>, align 4
+  // CHECK-NEXT:   %a.addr = alloca <3 x i8>, align 4
+  // CHECK-NEXT:   store i32 %a.coerce, ptr %a, align 4
+  // CHECK-NEXT:   %loadVec4 = load <4 x i8>, ptr %a, align 4
+  // CHECK-NEXT:   %a1 = shufflevector <4 x i8> %loadVec4, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %extractVec = shufflevector <3 x i8> %a1, <3 x i8> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
+  // CHECK-NEXT:   store <4 x i8> %extractVec, ptr %a.addr, align 4
+  // CHECK-NEXT:   %loadVec42 = load <4 x i8>, ptr %a.addr, align 4
+  // CHECK-NEXT:   %extractVec3 = shufflevector <4 x i8> %loadVec42, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %loadVec44 = load <4 x i8>, ptr %a.addr, align 4
+  // CHECK-NEXT:   %extractVec5 = shufflevector <4 x i8> %loadVec44, <4 x i8> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %add = add <3 x i8> %extractVec3, %extractVec5
+  // CHECK-NEXT:   store <3 x i8> %add, ptr %retval, align 4
+  // CHECK-NEXT:   %0 = load i32, ptr %retval, align 4
+  // CHECK-NEXT:   ret i32 %0
+  return a + a;
+}
+
+// CHECK-LABEL: define dso_local noundef <3 x i32> @_Z2v2Dv3_DB32_(<3 x i32> noundef %a)
+i32x3 v2(i32x3 a) {
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %a.addr = alloca <3 x i32>, align 16
+  // CHECK-NEXT:   %extractVec = shufflevector <3 x i32> %a, <3 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
+  // CHECK-NEXT:   store <4 x i32> %extractVec, ptr %a.addr, align 16
+  // CHECK-NEXT:   %loadVec4 = load <4 x i32>, ptr %a.addr, align 16
+  // CHECK-NEXT:   %extractVec1 = shufflevector <4 x i32> %loadVec4, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %loadVec42 = load <4 x i32>, ptr %a.addr, align 16
+  // CHECK-NEXT:   %extractVec3 = shufflevector <4 x i32> %loadVec42, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %add = add <3 x i32> %extractVec1, %extractVec3
+  // CHECK-NEXT:   ret <3 x i32> %add
+  return a + a;
+}
+
+// CHECK-LABEL: define dso_local noundef <3 x i512> @_Z2v3Dv3_DB512_(ptr noundef byval(<3 x i512>) align 256 %0)
+i512x3 v3(i512x3 a) {
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %a.addr = alloca <3 x i512>, align 256
+  // CHECK-NEXT:   %loadVec4 = load <4 x i512>, ptr %0, align 256
+  // CHECK-NEXT:   %a = shufflevector <4 x i512> %loadVec4, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %extractVec = shufflevector <3 x i512> %a, <3 x i512> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 poison>
+  // CHECK-NEXT:   store <4 x i512> %extractVec, ptr %a.addr, align 256
+  // CHECK-NEXT:   %loadVec41 = load <4 x i512>, ptr %a.addr, align 256
+  // CHECK-NEXT:   %extractVec2 = shufflevector <4 x i512> %loadVec41, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %loadVec43 = load <4 x i512>, ptr %a.addr, align 256
+  // CHECK-NEXT:   %extractVec4 = shufflevector <4 x i512> %loadVec43, <4 x i512> poison, <3 x i32> <i32 0, i32 1, i32 2>
+  // CHECK-NEXT:   %add = add <3 x i512> %extractVec2, %extractVec4
+  // CHECK-NEXT:   ret <3 x i512> %add
+  return a + a;
+}
+
+// CHECK-LABEL: define dso_local noundef <9 x i8> @_Z2m1u11matrix_typeILm3ELm3EDB8_E(<9 x i8> noundef %a)
+i8x3x3 m1(i8x3x3 a) {
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %a.addr = alloca [9 x i8], align 1
+  // CHECK-NEXT:   store <9 x i8> %a, ptr %a.addr, align 1
+  // CHECK-NEXT:   %0 = load <9 x i8>, ptr %a.addr, align 1
+  // CHECK-NEXT:   %1 = load <9 x i8>, ptr %a.addr, align 1
+  // CHECK-NEXT:   %2 = add <9 x i8> %0, %1
+  // CHECK-NEXT:   ret <9 x i8> %2
+  return a + a;
+}
+
+// CHECK-LABEL: define dso_local noundef <9 x i32> @_Z2m2u11matrix_typeILm3ELm3EDB32_E(<9 x i32> noundef %a)
+i32x3x3 m2(i32x3x3 a) {
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %a.addr = alloca [9 x i32], align 4
+  // CHECK-NEXT:   store <9 x i32> %a, ptr %a.addr, align 4
+  // CHECK-NEXT:   %0 = load <9 x i32>, ptr %a.addr, align 4
+  // CHECK-NEXT:   %1 = load <9 x i32>, ptr %a.addr, align 4
+  // CHECK-NEXT:   %2 = add <9 x i32> %0, %1
+  // CHECK-NEXT:   ret <9 x i32> %2
+  return a + a;
+}
+
+// CHECK-LABEL: define dso_local noundef <9 x i512> @_Z2m3u11matrix_typeILm3ELm3EDB512_E(<9 x i512> noundef %a)
+i512x3x3 m3(i512x3x3 a) {
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %a.addr = alloca [9 x i512], align 8
+  // CHECK-NEXT:   store <9 x i512> %a, ptr %a.addr, align 8
+  // CHECK-NEXT:   %0 = load <9 x i512>, ptr %a.addr, align 8
+  // CHECK-NEXT:   %1 = load <9 x i512>, ptr %a.addr, align 8
+  // CHECK-NEXT:   %2 = add <9 x i512> %0, %1
+  // CHECK-NEXT:   ret <9 x i512> %2
+  return a + a;
+}

>From 0d8fd8fbe1b9b018871c0a120d4484da56ff1d7e Mon Sep 17 00:00:00 2001
From: Sirraide <aeternalmail at gmail.com>
Date: Wed, 4 Dec 2024 09:10:34 +0100
Subject: [PATCH 5/5] Update docs to mention that non-power-of-2 _BitInt
 matrices are disallowed

---
 clang/docs/MatrixTypes.rst | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/clang/docs/MatrixTypes.rst b/clang/docs/MatrixTypes.rst
index e32e13b73aba61..4b3a0f42f20b4f 100644
--- a/clang/docs/MatrixTypes.rst
+++ b/clang/docs/MatrixTypes.rst
@@ -33,9 +33,10 @@ program is ill-formed.
 Currently, the element type of a matrix is only permitted to be one of the
 following types:
 
-* an integer type (as in C23 6.2.5p22), but excluding enumerated types and ``bool``
-* the standard floating types ``float`` or ``double``
-* a half-precision floating point type, if one is supported on the target
+* an integer type (as in C23 6.2.5p22), but excluding enumerated types, ``bool``,
+  and ``_BitInt`` types whose width is less than 8 or not a power of 2;
+* the standard floating types ``float`` or ``double``;
+* a half-precision floating point type, if one is supported on the target.
 
 Other types may be supported in the future.
 



More information about the cfe-commits mailing list