[llvm] [LLVM] Support target extension types in vectors (PR #140630)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 5 13:42:40 PDT 2025


https://github.com/arysef updated https://github.com/llvm/llvm-project/pull/140630

>From 78a86cf8c1ca1bc3f0336ae0dcf71fab5ef508b9 Mon Sep 17 00:00:00 2001
From: Aryan Sefidi <asefidi at microsoft.com>
Date: Mon, 19 May 2025 21:41:01 +0000
Subject: [PATCH 1/2] Support target extension types in vectors

---
 llvm/lib/IR/Type.cpp                            |  3 ++-
 llvm/test/Verifier/target-ext-vector-invalid.ll |  8 ++++++++
 llvm/test/Verifier/target-ext-vector.ll         | 11 +++++++++++
 3 files changed, 21 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Verifier/target-ext-vector-invalid.ll
 create mode 100644 llvm/test/Verifier/target-ext-vector.ll

diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 4b43b52014484..852b4897ed013 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -791,7 +791,8 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
 
 bool VectorType::isValidElementType(Type *ElemTy) {
   return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
-         ElemTy->isPointerTy() || ElemTy->getTypeID() == TypedPointerTyID;
+         ElemTy->isPointerTy() || ElemTy->getTypeID() == TypedPointerTyID || 
+         (ElemTy->isTargetExtTy() && ElemTy->isSized());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/Verifier/target-ext-vector-invalid.ll b/llvm/test/Verifier/target-ext-vector-invalid.ll
new file mode 100644
index 0000000000000..2b6a785db986c
--- /dev/null
+++ b/llvm/test/Verifier/target-ext-vector-invalid.ll
@@ -0,0 +1,8 @@
+; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s
+
+; CHECK: invalid vector element type
+
+define void @bad() {
+  %v = alloca <2 x target("spirv.IntegralConstant")>
+  ret void
+}
\ No newline at end of file
diff --git a/llvm/test/Verifier/target-ext-vector.ll b/llvm/test/Verifier/target-ext-vector.ll
new file mode 100644
index 0000000000000..a9129d4dfaa55
--- /dev/null
+++ b/llvm/test/Verifier/target-ext-vector.ll
@@ -0,0 +1,11 @@
+; RUN: llvm-as -o - %s | llvm-dis | FileCheck %s
+
+; CHECK-LABEL: @vec_ops(
+define <2 x target("spirv.Image")> @vec_ops(<2 x target("spirv.Image")> %x) {
+  %a = alloca <2 x target("spirv.Image")>
+  store <2 x target("spirv.Image")> %x, ptr %a
+  %load = load <2 x target("spirv.Image")>, ptr %a
+  %elt = extractelement <2 x target("spirv.Image")> %load, i64 0
+  %res = insertelement <2 x target("spirv.Image")> undef, target("spirv.Image") %elt, i64 1
+  ret <2 x target("spirv.Image")> %res
+}
\ No newline at end of file

>From 33244621c32114827800f2191d444c7f82043af1 Mon Sep 17 00:00:00 2001
From: Aryan Sefidi <asefidi at microsoft.com>
Date: Thu, 5 Jun 2025 20:42:26 +0000
Subject: [PATCH 2/2] add property to opt in to being a valid vector element
 type

---
 llvm/docs/LangRef.rst                         |  3 +-
 llvm/include/llvm/IR/DerivedTypes.h           |  2 ++
 llvm/lib/IR/Type.cpp                          | 22 ++++++++++----
 .../Verifier/target-ext-vector-invalid.ll     |  2 +-
 llvm/test/Verifier/target-ext-vector.ll       | 29 ++++++++++++-------
 5 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 0958f6a4b729b..e3b7c223b9c8d 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -4438,7 +4438,8 @@ the type size is smaller than the type's store size.
       < vscale x <# elements> x <elementtype> > ; Scalable vector
 
 The number of elements is a constant integer value larger than 0;
-elementtype may be any integer, floating-point or pointer type. Vectors
+elementtype may be any integer, floating-point, or pointer type, or a sized
+target extension type that has the ``CanBeVectorElement`` property. Vectors
 of size zero are not allowed. For scalable vectors, the total number of
 elements is a constant multiple (called vscale) of the specified number
 of elements; vscale is a positive integer that is unknown at compile time
diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
index 4d6bb1cfe3069..fa62bc09b61a3 100644
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -845,6 +845,8 @@ class TargetExtType : public Type {
     /// This type may be allocated on the stack, either as the allocated type
     /// of an alloca instruction or as a byval function parameter.
     CanBeLocal = 1U << 2,
+    /// This type may be used as an element in a vector.
+    CanBeVectorElement = 1U << 3,
   };
 
   /// Returns true if the target extension type contains the given property.
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index 852b4897ed013..7858e24f4fce7 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -790,9 +790,13 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
 }
 
 bool VectorType::isValidElementType(Type *ElemTy) {
-  return ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
-         ElemTy->isPointerTy() || ElemTy->getTypeID() == TypedPointerTyID || 
-         (ElemTy->isTargetExtTy() && ElemTy->isSized());
+  if (ElemTy->isIntegerTy() || ElemTy->isFloatingPointTy() ||
+      ElemTy->isPointerTy() || ElemTy->getTypeID() == TypedPointerTyID)
+    return true;
+  if (auto *TTy = dyn_cast<TargetExtType>(ElemTy))
+    return TTy->hasProperty(TargetExtType::CanBeVectorElement) &&
+           TTy->isSized();
+  return false;
 }
 
 //===----------------------------------------------------------------------===//
@@ -802,8 +806,9 @@ bool VectorType::isValidElementType(Type *ElemTy) {
 FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
   assert(NumElts > 0 && "#Elements of a VectorType must be greater than 0");
   assert(isValidElementType(ElementType) && "Element type of a VectorType must "
-                                            "be an integer, floating point, or "
-                                            "pointer type.");
+                                            "be an integer, floating point, "
+                                            "pointer type, or a valid target "
+                                            "extension type.");
 
   auto EC = ElementCount::getFixed(NumElts);
 
@@ -1038,6 +1043,13 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
                           TargetExtType::CanBeGlobal);
   }
 
+  // Type used to test vector element target extension property.
+  // Can be removed once a public target extension type uses CanBeVectorElement.
+  if (Name == "llvm.test.vectorelement") {
+    return TargetTypeInfo(Type::getInt32Ty(C), TargetExtType::CanBeLocal,
+                          TargetExtType::CanBeVectorElement);
+  }
+
   return TargetTypeInfo(Type::getVoidTy(C));
 }
 
diff --git a/llvm/test/Verifier/target-ext-vector-invalid.ll b/llvm/test/Verifier/target-ext-vector-invalid.ll
index 2b6a785db986c..59e3e78276a06 100644
--- a/llvm/test/Verifier/target-ext-vector-invalid.ll
+++ b/llvm/test/Verifier/target-ext-vector-invalid.ll
@@ -3,6 +3,6 @@
 ; CHECK: invalid vector element type
 
 define void @bad() {
-  %v = alloca <2 x target("spirv.IntegralConstant")>
+  %v = alloca <2 x target("spirv.Image")>
   ret void
 }
\ No newline at end of file
diff --git a/llvm/test/Verifier/target-ext-vector.ll b/llvm/test/Verifier/target-ext-vector.ll
index a9129d4dfaa55..43d8360f39da0 100644
--- a/llvm/test/Verifier/target-ext-vector.ll
+++ b/llvm/test/Verifier/target-ext-vector.ll
@@ -1,11 +1,20 @@
-; RUN: llvm-as -o - %s | llvm-dis | FileCheck %s
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=verify -S %s | FileCheck %s
 
-; CHECK-LABEL: @vec_ops(
-define <2 x target("spirv.Image")> @vec_ops(<2 x target("spirv.Image")> %x) {
-  %a = alloca <2 x target("spirv.Image")>
-  store <2 x target("spirv.Image")> %x, ptr %a
-  %load = load <2 x target("spirv.Image")>, ptr %a
-  %elt = extractelement <2 x target("spirv.Image")> %load, i64 0
-  %res = insertelement <2 x target("spirv.Image")> undef, target("spirv.Image") %elt, i64 1
-  ret <2 x target("spirv.Image")> %res
-}
\ No newline at end of file
+define <2 x target("llvm.test.vectorelement")> @vec_ops(<2 x target("llvm.test.vectorelement")> %x) {
+; CHECK-LABEL: define <2 x target("llvm.test.vectorelement")> @vec_ops(
+; CHECK-SAME: <2 x target("llvm.test.vectorelement")> [[X:%.*]]) {
+; CHECK-NEXT:    [[A:%.*]] = alloca <2 x target("llvm.test.vectorelement")>{{.*}}
+; CHECK-NEXT:    store <2 x target("llvm.test.vectorelement")> [[X]], ptr [[A]], {{.*}}
+; CHECK-NEXT:    [[LOAD:%.*]] = load <2 x target("llvm.test.vectorelement")>, ptr [[A]], {{.*}}
+; CHECK-NEXT:    [[ELT:%.*]] = extractelement <2 x target("llvm.test.vectorelement")> [[LOAD]], i64 0
+; CHECK-NEXT:    [[RES:%.*]] = insertelement <2 x target("llvm.test.vectorelement")> undef, target("llvm.test.vectorelement") [[ELT]], i64 1
+; CHECK-NEXT:    ret <2 x target("llvm.test.vectorelement")> [[RES]]
+;
+  %a = alloca <2 x target("llvm.test.vectorelement")>
+  store <2 x target("llvm.test.vectorelement")> %x, ptr %a
+  %load = load <2 x target("llvm.test.vectorelement")>, ptr %a
+  %elt = extractelement <2 x target("llvm.test.vectorelement")> %load, i64 0
+  %res = insertelement <2 x target("llvm.test.vectorelement")> undef, target("llvm.test.vectorelement") %elt, i64 1
+  ret <2 x target("llvm.test.vectorelement")> %res
+}



More information about the llvm-commits mailing list