[clang] [CIR][NFC] Upstream VectorType support in helper function (PR #142222)

Amr Hesham via cfe-commits cfe-commits at lists.llvm.org
Sat May 31 09:01:29 PDT 2025


https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/142222

>From 77d3586fede008cf24e25d7bf30c73049de893b8 Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Sat, 31 May 2025 01:47:16 +0200
Subject: [PATCH 1/2] [CIR][NFC] Upstream VectorType support in helper function

---
 .../CIR/Dialect/IR/CIRTypeConstraints.td      | 37 +++++++++++++++++++
 clang/include/clang/CIR/Dialect/IR/CIRTypes.h |  2 -
 clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp    |  6 +--
 clang/lib/CIR/Dialect/IR/CIRTypes.cpp         |  9 -----
 4 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index ec461cab961c7..7b20ca4e2d1d4 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -31,6 +31,18 @@ class CIR_ConfinedType<Type type, list<Pred> preds, string summary = "">
     : Type<And<[type.predicate, CIR_CastedSelfsToType<type.cppType, preds>]>,
          summary, type.cppType>;
 
+// Generates a type summary.
+// - For a single type: returns its summary.
+// - For multiple types: returns `any of <comma-separated summaries>`.
+class CIR_TypeSummaries<list<Type> types> {
+    assert !not(!empty(types)), "expects non-empty list of types";
+
+    list<string> summaries = !foreach(type, types, type.summary);
+    string joined = !interleave(summaries, ", ");
+
+    string value = !if(!eq(!size(types), 1), joined, "any of " # joined);
+}
+
 //===----------------------------------------------------------------------===//
 // Bool Type predicates
 //===----------------------------------------------------------------------===//
@@ -184,6 +196,8 @@ def CIR_PtrToVoidPtrType
 // Vector Type predicates
 //===----------------------------------------------------------------------===//
 
+def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
+
 // Vector of integral type
 def IntegerVector : Type<
     And<[
@@ -211,4 +225,27 @@ def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
   let cppFunctionName = "isScalarType";
 }
 
+//===----------------------------------------------------------------------===//
+// Element type constraint bases
+//===----------------------------------------------------------------------===//
+
+class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
+    "::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
+
+class CIR_VectorTypeOf<list<Type> types, string summary = "">
+    : CIR_ConfinedType<CIR_AnyVectorType,
+        [Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
+        !if(!empty(summary),
+            "vector of " # CIR_TypeSummaries<types>.value,
+            summary)>;
+
+// Vector of type constraints
+def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
+
+def CIR_AnyFloatOrVecOfFloatType
+    : AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
+        "floating point or vector of floating point type"> {
+    let cppFunctionName = "isFPOrVectorOfFPType";
+}
+
 #endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.h b/clang/include/clang/CIR/Dialect/IR/CIRTypes.h
index 3845fd2a4b67d..49933be724a04 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.h
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.h
@@ -26,8 +26,6 @@ struct RecordTypeStorage;
 
 bool isValidFundamentalIntWidth(unsigned width);
 
-bool isFPOrFPVectorTy(mlir::Type);
-
 } // namespace cir
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 8448c164a5e58..b33bb71c99c90 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1311,7 +1311,7 @@ mlir::Value ScalarExprEmitter::emitMul(const BinOpInfo &ops) {
       !canElideOverflowCheck(cgf.getContext(), ops))
     cgf.cgm.errorNYI("unsigned int overflow sanitizer");
 
-  if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
+  if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
     assert(!cir::MissingFeatures::cgFPOptionsRAII());
     return builder.createFMul(loc, ops.lhs, ops.rhs);
   }
@@ -1370,7 +1370,7 @@ mlir::Value ScalarExprEmitter::emitAdd(const BinOpInfo &ops) {
       !canElideOverflowCheck(cgf.getContext(), ops))
     cgf.cgm.errorNYI("unsigned int overflow sanitizer");
 
-  if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
+  if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
     assert(!cir::MissingFeatures::cgFPOptionsRAII());
     return builder.createFAdd(loc, ops.lhs, ops.rhs);
   }
@@ -1418,7 +1418,7 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &ops) {
         !canElideOverflowCheck(cgf.getContext(), ops))
       cgf.cgm.errorNYI("unsigned int overflow sanitizer");
 
-    if (cir::isFPOrFPVectorTy(ops.lhs.getType())) {
+    if (cir::isFPOrVectorOfFPType(ops.lhs.getType())) {
       assert(!cir::MissingFeatures::cgFPOptionsRAII());
       return builder.createFSub(loc, ops.lhs, ops.rhs);
     }
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index b402177a5ec18..21d957afefeb6 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -552,15 +552,6 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
       .getABIAlignment(dataLayout, params);
 }
 
-//===----------------------------------------------------------------------===//
-// Floating-point and Float-point Vector type helpers
-//===----------------------------------------------------------------------===//
-
-bool cir::isFPOrFPVectorTy(mlir::Type t) {
-  assert(!cir::MissingFeatures::vectorType());
-  return isAnyFloatingPointType(t);
-}
-
 //===----------------------------------------------------------------------===//
 // FuncType Definitions
 //===----------------------------------------------------------------------===//

>From f0a401f359f2ca56722c77951beeacc3fb1485ea Mon Sep 17 00:00:00 2001
From: AmrDeveloper <amr96 at programmer.net>
Date: Sat, 31 May 2025 17:59:36 +0200
Subject: [PATCH 2/2] Upstream Refactored vector type constraints

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  | 18 +++--
 .../CIR/Dialect/IR/CIRTypeConstraints.td      | 71 ++++++++++++-------
 .../include/clang/CIR/Dialect/IR/CIRTypes.td  | 25 +++++--
 clang/lib/CIR/Dialect/IR/CIRTypes.cpp         | 12 +---
 clang/test/CIR/IR/invalid-vector.cir          |  2 +-
 5 files changed, 78 insertions(+), 50 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 07851610a2abd..237daed32532a 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,9 +1464,13 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
     ```
   }];
 
-  let results = (outs CIR_AnyIntOrVecOfInt:$result);
-  let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
-                       UnitAttr:$isShiftleft);
+  let arguments = (ins
+    CIR_AnyIntOrVecOfIntType:$value,
+    CIR_AnyIntOrVecOfIntType:$amount,
+    UnitAttr:$isShiftleft
+  );
+
+  let results = (outs CIR_AnyIntOrVecOfIntType:$result);
 
   let assemblyFormat = [{
     `(`
@@ -2050,7 +2054,7 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
     in the vector type.
   }];
 
-  let arguments = (ins Variadic<CIR_AnyType>:$elements);
+  let arguments = (ins Variadic<CIR_VectorElementType>:$elements);
   let results = (outs CIR_VectorType:$result);
 
   let assemblyFormat = [{
@@ -2085,7 +2089,7 @@ def VecInsertOp : CIR_Op<"vec.insert", [Pure,
 
   let arguments = (ins
     CIR_VectorType:$vec,
-    AnyType:$value,
+    CIR_VectorElementType:$value,
     CIR_AnyFundamentalIntType:$index
   );
 
@@ -2118,7 +2122,7 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
   }];
 
   let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
-  let results = (outs CIR_AnyType:$result);
+  let results = (outs CIR_VectorElementType:$result);
 
   let assemblyFormat = [{
     $vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
@@ -2180,7 +2184,7 @@ def VecShuffleDynamicOp : CIR_Op<"vec.shuffle.dynamic",
     ```
   }];
 
-  let arguments = (ins CIR_VectorType:$vec, IntegerVector:$indices);
+  let arguments = (ins CIR_VectorType:$vec, CIR_VectorOfIntType:$indices);
   let results = (outs CIR_VectorType:$result);
   let assemblyFormat = [{
     $vec `:` qualified(type($vec)) `,` $indices `:` qualified(type($indices))
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
index 7b20ca4e2d1d4..bcd516e27cc76 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
@@ -198,6 +198,22 @@ def CIR_PtrToVoidPtrType
 
 def CIR_AnyVectorType : CIR_TypeBase<"::cir::VectorType", "vector type">;
 
+def CIR_VectorElementType : AnyTypeOf<[CIR_AnyIntOrFloatType, CIR_AnyPtrType],
+    "any cir integer, floating point or pointer type"
+> {
+    let cppFunctionName = "isValidVectorTypeElementType";
+}
+
+class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
+    "::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
+
+class CIR_VectorTypeOf<list<Type> types, string summary = "">
+    : CIR_ConfinedType<CIR_AnyVectorType,
+        [Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
+        !if(!empty(summary),
+            "vector of " # CIR_TypeSummaries<types>.value,
+            summary)>;
+
 // Vector of integral type
 def IntegerVector : Type<
     And<[
@@ -210,8 +226,36 @@ def IntegerVector : Type<
     ]>, "!cir.vector of !cir.int"> {
 }
 
-// Any Integer or Vector of Integer Constraints
-def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_AnyIntType, IntegerVector]>;
+// Vector of type constraints
+def CIR_VectorOfIntType : CIR_VectorTypeOf<[CIR_AnyIntType]>;
+def CIR_VectorOfUIntType : CIR_VectorTypeOf<[CIR_AnyUIntType]>;
+def CIR_VectorOfSIntType : CIR_VectorTypeOf<[CIR_AnySIntType]>;
+def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
+
+// Vector or Scalar type constraints
+def CIR_AnyIntOrVecOfIntType
+    : AnyTypeOf<[CIR_AnyIntType, CIR_VectorOfIntType],
+        "integer or vector of integer type"> {
+    let cppFunctionName = "isIntOrVectorOfIntType";
+}
+
+def CIR_AnySIntOrVecOfSIntType
+    : AnyTypeOf<[CIR_AnySIntType, CIR_VectorOfSIntType],
+        "signed integer or vector of signed integer type"> {
+    let cppFunctionName = "isSIntOrVectorOfSIntType";
+}
+
+def CIR_AnyUIntOrVecOfUIntType
+    : AnyTypeOf<[CIR_AnyUIntType, CIR_VectorOfUIntType],
+        "unsigned integer or vector of unsigned integer type"> {
+    let cppFunctionName = "isUIntOrVectorOfUIntType";
+}
+
+def CIR_AnyFloatOrVecOfFloatType
+    : AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
+        "floating point or vector of floating point type"> {
+    let cppFunctionName = "isFPOrVectorOfFPType";
+}
 
 //===----------------------------------------------------------------------===//
 // Scalar Type predicates
@@ -225,27 +269,4 @@ def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {
   let cppFunctionName = "isScalarType";
 }
 
-//===----------------------------------------------------------------------===//
-// Element type constraint bases
-//===----------------------------------------------------------------------===//
-
-class CIR_ElementTypePred<Pred pred> : SubstLeaves<"$_self",
-    "::mlir::cast<::cir::VectorType>($_self).getElementType()", pred>;
-
-class CIR_VectorTypeOf<list<Type> types, string summary = "">
-    : CIR_ConfinedType<CIR_AnyVectorType,
-        [Or<!foreach(type, types, CIR_ElementTypePred<type.predicate>)>],
-        !if(!empty(summary),
-            "vector of " # CIR_TypeSummaries<types>.value,
-            summary)>;
-
-// Vector of type constraints
-def CIR_VectorOfFloatType : CIR_VectorTypeOf<[CIR_AnyFloatType]>;
-
-def CIR_AnyFloatOrVecOfFloatType
-    : AnyTypeOf<[CIR_AnyFloatType, CIR_VectorOfFloatType],
-        "floating point or vector of floating point type"> {
-    let cppFunctionName = "isFPOrVectorOfFPType";
-}
-
 #endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
index 26f1122a4b261..9c0af8d3eaa5f 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td
@@ -275,18 +275,31 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",
 
   let summary = "CIR vector type";
   let description = [{
-   `!cir.vector' represents fixed-size vector types, parameterized
-    by the element type and the number of elements.
+    The `!cir.vector` type represents a fixed-size, one-dimensional vector.
+    It takes two parameters: the element type and the number of elements.
 
-    Example:
+    Syntax:
 
     ```mlir
-    !cir.vector<!u64i x 2>
-    !cir.vector<!cir.float x 4>
+    vector-type ::= !cir.vector<size x element-type>
+    element-type ::= float-type | integer-type | pointer-type
+    ```
+
+    The `element-type` must be a scalar CIR integer, floating-point, or
+    pointer type. The `size` must be positive; zero-sized vectors are rejected.
+
+    Examples:
+
+    ```mlir
+    !cir.vector<4 x !cir.int<u, 8>>
+    !cir.vector<2 x !cir.float>
     ```
   }];
 
-  let parameters = (ins "mlir::Type":$elementType, "uint64_t":$size);
+  let parameters = (ins
+    CIR_VectorElementType:$elementType,
+    "uint64_t":$size
+  );
 
   let assemblyFormat = [{
     `<` $size `x` $elementType `>`
diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
index 21d957afefeb6..900871c3c2cba 100644
--- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp
@@ -684,17 +684,7 @@ mlir::LogicalResult cir::VectorType::verify(
     mlir::Type elementType, uint64_t size) {
   if (size == 0)
     return emitError() << "the number of vector elements must be non-zero";
-
-  // Check if it a valid FixedVectorType
-  if (mlir::isa<cir::PointerType, cir::FP128Type>(elementType))
-    return success();
-
-  // Check if it a valid VectorType
-  if (mlir::isa<cir::IntType>(elementType) ||
-      isAnyFloatingPointType(elementType))
-    return success();
-
-  return emitError() << "unsupported element type for CIR vector";
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/test/CIR/IR/invalid-vector.cir b/clang/test/CIR/IR/invalid-vector.cir
index d94eacedec1f5..72676a4718e19 100644
--- a/clang/test/CIR/IR/invalid-vector.cir
+++ b/clang/test/CIR/IR/invalid-vector.cir
@@ -4,7 +4,7 @@
 
 module  {
 
-// expected-error @below {{unsupported element type for CIR vector}}
+// expected-error @below {{failed to verify 'elementType'}}
 cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>
 
 }



More information about the cfe-commits mailing list