[llvm] [VFABI] Add support for vector functions that return struct types (PR #119000)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 12 09:19:49 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/119000

>From 16afa09ae322b88ad452a08fff984463eb306a2a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 6 Dec 2024 14:26:53 +0000
Subject: [PATCH 1/2] [VFABI] Add support for vector functions that return
 struct types

This patch updates the `VFABIDemangler` to support vector functions that
return struct types. For example, a vector variant of `sincos` that
returns a vector of sine values and a vector of cosine values within a
struct.

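This patch also adds some helpers in `llvm/IR/CallWideningUtils.h` for
widening call return types. Some of these are used in the
`VFABIDemangler`; the others will be used in subsequent patches, so for
now this patch only adds unit tests for them.

It also moves `ToVectorTy` from `llvm/Analysis/VectorUtils.h` into a new
`llvm/IR/VectorUtils.h` (which the Analysis header now includes), so that
the new IR-level helpers can use it.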
---
 llvm/include/llvm/Analysis/VectorUtils.h    |  14 +-
 llvm/include/llvm/IR/CallWideningUtils.h    |  44 ++++++
 llvm/include/llvm/IR/VectorUtils.h          |  32 +++++
 llvm/lib/IR/CMakeLists.txt                  |   1 +
 llvm/lib/IR/CallWideningUtils.cpp           |  73 ++++++++++
 llvm/lib/IR/VFABIDemangler.cpp              |  18 ++-
 llvm/unittests/IR/CMakeLists.txt            |   1 +
 llvm/unittests/IR/CallWideningUtilsTest.cpp | 149 ++++++++++++++++++++
 llvm/unittests/IR/VFABIDemanglerTest.cpp    |  67 +++++++++
 9 files changed, 379 insertions(+), 20 deletions(-)
 create mode 100644 llvm/include/llvm/IR/CallWideningUtils.h
 create mode 100644 llvm/include/llvm/IR/VectorUtils.h
 create mode 100644 llvm/lib/IR/CallWideningUtils.cpp
 create mode 100644 llvm/unittests/IR/CallWideningUtilsTest.cpp

diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c1016dd7bdddbd..5433231a1018e9 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/VFABIDemangler.h"
+#include "llvm/IR/VectorUtils.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
 namespace llvm {
@@ -127,19 +128,6 @@ namespace Intrinsic {
 typedef unsigned ID;
 }
 
-/// A helper function for converting Scalar types to vector types. If
-/// the incoming type is void, we return void. If the EC represents a
-/// scalar, we return the scalar type.
-inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
-  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
-    return Scalar;
-  return VectorType::get(Scalar, EC);
-}
-
-inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
-  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
-}
-
 /// Identify if the intrinsic is trivially vectorizable.
 /// This method returns true if the intrinsic's argument types are all scalars
 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
diff --git a/llvm/include/llvm/IR/CallWideningUtils.h b/llvm/include/llvm/IR/CallWideningUtils.h
new file mode 100644
index 00000000000000..de51c8f6c6ba1f
--- /dev/null
+++ b/llvm/include/llvm/IR/CallWideningUtils.h
@@ -0,0 +1,44 @@
+//===---- CallWideningUtils.h - Utils for widening scalar to vector calls --==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_CALLWIDENINGUTILS_H
+#define LLVM_IR_CALLWIDENINGUTILS_H
+
+#include "llvm/IR/DerivedTypes.h"
+
+namespace llvm {
+
+/// A helper for converting to wider (vector) types. For scalar types, this is
+/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
+/// struct where each element type has been widened to a vector type. Note: Only
+/// unpacked literal struct types are supported.
+Type *ToWideTy(Type *Ty, ElementCount EC);
+
+/// A helper for converting wide types to narrow (non-vector) types. For vector
+/// types, this is equivalent to calling .getScalarType(). For struct types,
+/// this returns a new struct where each element type has been converted to a
+/// scalar type. Note: Only unpacked literal struct types are supported.
+Type *ToNarrowTy(Type *Ty);
+
+/// Returns the types contained in `Ty`. For struct types, it returns the
+/// elements, all other types are returned directly.
+SmallVector<Type *, 2> getContainedTypes(Type *Ty);
+
+/// Returns true if `Ty` is a vector type or a struct of vector types where all
+/// vector types share the same VF.
+bool isWideTy(Type *Ty);
+
+/// Returns the vectorization factor for a widened type.
+inline ElementCount getWideTypeVF(Type *Ty) {
+  assert(isWideTy(Ty) && "expected widened type");
+  return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
+}
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/include/llvm/IR/VectorUtils.h b/llvm/include/llvm/IR/VectorUtils.h
new file mode 100644
index 00000000000000..a2e34a02e08a67
--- /dev/null
+++ b/llvm/include/llvm/IR/VectorUtils.h
@@ -0,0 +1,32 @@
+//===----------- VectorUtils.h -  Vector type utility functions -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_VECTORUTILS_H
+#define LLVM_IR_VECTORUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/DerivedTypes.h"
+
+namespace llvm {
+
+/// A helper function for converting Scalar types to vector types. If
+/// the incoming type is void, we return void. If the EC represents a
+/// scalar, we return the scalar type.
+inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
+  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
+    return Scalar;
+  return VectorType::get(Scalar, EC);
+}
+
+inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
+  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
+}
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 544f4ea9223d0e..874ee8c4795a96 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_llvm_component_library(LLVMCore
   AutoUpgrade.cpp
   BasicBlock.cpp
   BuiltinGCs.cpp
+  CallWideningUtils.cpp
   Comdat.cpp
   ConstantFold.cpp
   ConstantFPRange.cpp
diff --git a/llvm/lib/IR/CallWideningUtils.cpp b/llvm/lib/IR/CallWideningUtils.cpp
new file mode 100644
index 00000000000000..ec0bc6e3baa463
--- /dev/null
+++ b/llvm/lib/IR/CallWideningUtils.cpp
@@ -0,0 +1,73 @@
+//===----------- VectorUtils.cpp - Vector type utility functions ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/CallWideningUtils.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/IR/VectorUtils.h"
+
+using namespace llvm;
+
+/// A helper for converting to wider (vector) types. For scalar types, this is
+/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
+/// struct where each element type has been widened to a vector type. Note: Only
+/// unpacked literal struct types are supported.
+Type *llvm::ToWideTy(Type *Ty, ElementCount EC) {
+  if (EC.isScalar())
+    return Ty;
+  auto *StructTy = dyn_cast<StructType>(Ty);
+  if (!StructTy)
+    return ToVectorTy(Ty, EC);
+  assert(StructTy->isLiteral() && !StructTy->isPacked() &&
+         "expected unpacked struct literal");
+  return StructType::get(
+      Ty->getContext(),
+      map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
+        return VectorType::get(ElTy, EC);
+      }));
+}
+
+/// A helper for converting wide types to narrow (non-vector) types. For vector
+/// types, this is equivalent to calling .getScalarType(). For struct types,
+/// this returns a new struct where each element type has been converted to a
+/// scalar type. Note: Only unpacked literal struct types are supported.
+Type *llvm::ToNarrowTy(Type *Ty) {
+  auto *StructTy = dyn_cast<StructType>(Ty);
+  if (!StructTy)
+    return Ty->getScalarType();
+  assert(StructTy->isLiteral() && !StructTy->isPacked() &&
+         "expected unpacked struct literal");
+  return StructType::get(
+      Ty->getContext(),
+      map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
+        return ElTy->getScalarType();
+      }));
+}
+
+/// Returns the types contained in `Ty`. For struct types, it returns the
+/// elements, all other types are returned directly.
+SmallVector<Type *, 2> llvm::getContainedTypes(Type *Ty) {
+  auto *StructTy = dyn_cast<StructType>(Ty);
+  if (StructTy)
+    return to_vector<2>(StructTy->elements());
+  return {Ty};
+}
+
+/// Returns true if `Ty` is a vector type or a struct of vector types where all
+/// vector types share the same VF.
+bool llvm::isWideTy(Type *Ty) {
+  auto *StructTy = dyn_cast<StructType>(Ty);
+  if (StructTy && (!StructTy->isLiteral() || StructTy->isPacked()))
+    return false;
+  auto ContainedTys = getContainedTypes(Ty);
+  if (ContainedTys.empty() || !ContainedTys.front()->isVectorTy())
+    return false;
+  ElementCount VF = cast<VectorType>(ContainedTys.front())->getElementCount();
+  return all_of(ContainedTys, [&](Type *Ty) {
+    return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
+  });
+}
diff --git a/llvm/lib/IR/VFABIDemangler.cpp b/llvm/lib/IR/VFABIDemangler.cpp
index 897583084bf38c..19c922c8bf035b 100644
--- a/llvm/lib/IR/VFABIDemangler.cpp
+++ b/llvm/lib/IR/VFABIDemangler.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/CallWideningUtils.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -346,12 +347,15 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
   // Also check the return type if not void.
   Type *RetTy = Signature->getReturnType();
   if (!RetTy->isVoidTy()) {
-    std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
-    // If we have an unknown scalar element type we can't find a reasonable VF.
-    if (!ReturnEC)
-      return std::nullopt;
-    if (ElementCount::isKnownLT(*ReturnEC, MinEC))
-      MinEC = *ReturnEC;
+    for (Type *ElemTy : getContainedTypes(RetTy)) {
+      std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, ElemTy);
+      // If we have an unknown scalar element type (for struct returns, in any
+      // member) we can't find a reasonable VF.
+      if (!ReturnEC)
+        return std::nullopt;
+      if (ElementCount::isKnownLT(*ReturnEC, MinEC))
+        MinEC = *ReturnEC;
+    }
   }

   // The SVE Vector function call ABI bases the VF on the widest element types
@@ -566,7 +570,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,
 
   auto *RetTy = ScalarFTy->getReturnType();
   if (!RetTy->isVoidTy())
-    RetTy = VectorType::get(RetTy, VF);
+    RetTy = ToWideTy(RetTy, VF);
   return FunctionType::get(RetTy, VecTypes, false);
 }
 
diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
index ed93ee547d2231..c4cc1bd78403a8 100644
--- a/llvm/unittests/IR/CMakeLists.txt
+++ b/llvm/unittests/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_llvm_unittest(IRTests
   AttributesTest.cpp
   BasicBlockTest.cpp
   BasicBlockDbgInfoTest.cpp
+  CallWideningUtilsTest.cpp
   CFGBuilder.cpp
   ConstantFPRangeTest.cpp
   ConstantRangeTest.cpp
diff --git a/llvm/unittests/IR/CallWideningUtilsTest.cpp b/llvm/unittests/IR/CallWideningUtilsTest.cpp
new file mode 100644
index 00000000000000..939806212eee6a
--- /dev/null
+++ b/llvm/unittests/IR/CallWideningUtilsTest.cpp
@@ -0,0 +1,149 @@
+//===------- CallWideningUtilsTest.cpp - Call widening utils tests --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/CallWideningUtils.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+class CallWideningUtilsTest : public ::testing::Test {};
+
+TEST(CallWideningUtilsTest, TestToWideTy) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  Type *FTy = Type::getFloatTy(C);
+  Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+  Type *MixedStructTy = StructType::get(FTy, ITy);
+  Type *VoidTy = Type::getVoidTy(C);
+
+  for (ElementCount VF :
+       {ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
+    Type *IntVec = ToWideTy(ITy, VF);
+    EXPECT_TRUE(isa<VectorType>(IntVec));
+    EXPECT_EQ(IntVec, VectorType::get(ITy, VF));
+
+    Type *FloatVec = ToWideTy(FTy, VF);
+    EXPECT_TRUE(isa<VectorType>(FloatVec));
+    EXPECT_EQ(FloatVec, VectorType::get(FTy, VF));
+
+    Type *WideHomogeneousStructTy = ToWideTy(HomogeneousStructTy, VF);
+    EXPECT_TRUE(isa<StructType>(WideHomogeneousStructTy));
+    EXPECT_TRUE(
+        cast<StructType>(WideHomogeneousStructTy)->containsHomogeneousTypes());
+    EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getNumElements() ==
+                3);
+    EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getElementType(0) ==
+                VectorType::get(FTy, VF));
+
+    Type *WideMixedStructTy = ToWideTy(MixedStructTy, VF);
+    EXPECT_TRUE(isa<StructType>(WideMixedStructTy));
+    EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getNumElements() == 2);
+    EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(0) ==
+                VectorType::get(FTy, VF));
+    EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(1) ==
+                VectorType::get(ITy, VF));
+
+    EXPECT_EQ(ToWideTy(VoidTy, VF), VoidTy);
+  }
+
+  ElementCount ScalarVF = ElementCount::getFixed(1);
+  for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
+    EXPECT_EQ(ToWideTy(Ty, ScalarVF), Ty);
+  }
+}
+
+TEST(CallWideningUtilsTest, TestToNarrowTy) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  Type *FTy = Type::getFloatTy(C);
+  Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+  Type *MixedStructTy = StructType::get(FTy, ITy);
+  Type *VoidTy = Type::getVoidTy(C);
+
+  for (ElementCount VF : {ElementCount::getFixed(1), ElementCount::getFixed(4),
+                          ElementCount::getScalable(2)}) {
+    for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
+      // ToNarrowTy should be the inverse of ToWideTy.
+      EXPECT_EQ(ToNarrowTy(ToWideTy(Ty, VF)), Ty);
+    };
+  }
+}
+
+TEST(CallWideningUtilsTest, TestGetContainedTypes) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  Type *FTy = Type::getFloatTy(C);
+  Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+  Type *MixedStructTy = StructType::get(FTy, ITy);
+  Type *VoidTy = Type::getVoidTy(C);
+
+  EXPECT_EQ(getContainedTypes(ITy), SmallVector<Type *>({ITy}));
+  EXPECT_EQ(getContainedTypes(FTy), SmallVector<Type *>({FTy}));
+  EXPECT_EQ(getContainedTypes(VoidTy), SmallVector<Type *>({VoidTy}));
+  EXPECT_EQ(getContainedTypes(HomogeneousStructTy),
+            SmallVector<Type *>({FTy, FTy, FTy}));
+  EXPECT_EQ(getContainedTypes(MixedStructTy), SmallVector<Type *>({FTy, ITy}));
+}
+
+TEST(CallWideningUtilsTest, TestIsWideTy) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  Type *FTy = Type::getFloatTy(C);
+  Type *NarrowStruct = StructType::get(FTy, ITy);
+  Type *VoidTy = Type::getVoidTy(C);
+
+  EXPECT_FALSE(isWideTy(ITy));
+  EXPECT_FALSE(isWideTy(NarrowStruct));
+  EXPECT_FALSE(isWideTy(VoidTy));
+
+  ElementCount VF = ElementCount::getFixed(4);
+  EXPECT_TRUE(isWideTy(ToWideTy(ITy, VF)));
+  EXPECT_TRUE(isWideTy(ToWideTy(NarrowStruct, VF)));
+
+  Type *MixedVFStruct =
+      StructType::get(VectorType::get(ITy, ElementCount::getFixed(2)),
+                      VectorType::get(ITy, ElementCount::getFixed(4)));
+  EXPECT_FALSE(isWideTy(MixedVFStruct));
+
+  // Currently only literals types are considered wide.
+  Type *NamedWideStruct = StructType::create("Named", VectorType::get(ITy, VF),
+                                             VectorType::get(ITy, VF));
+  EXPECT_FALSE(isWideTy(NamedWideStruct));
+
+  // Currently only unpacked types are considered wide.
+  Type *PackedWideStruct = StructType::get(
+      C, ArrayRef<Type *>{VectorType::get(ITy, VF), VectorType::get(ITy, VF)},
+      /*isPacked=*/true);
+  EXPECT_FALSE(isWideTy(PackedWideStruct));
+}
+
+TEST(CallWideningUtilsTest, TestGetWideTypeVF) {
+  LLVMContext C;
+
+  Type *ITy = Type::getInt32Ty(C);
+  Type *FTy = Type::getFloatTy(C);
+  Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+  Type *MixedStructTy = StructType::get(FTy, ITy);
+
+  for (ElementCount VF :
+       {ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
+    for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy}) {
+      EXPECT_EQ(getWideTypeVF(ToWideTy(Ty, VF)), VF);
+    };
+  }
+}
+
+} // namespace
diff --git a/llvm/unittests/IR/VFABIDemanglerTest.cpp b/llvm/unittests/IR/VFABIDemanglerTest.cpp
index 07bff16df49335..896cd48ad11d6d 100644
--- a/llvm/unittests/IR/VFABIDemanglerTest.cpp
+++ b/llvm/unittests/IR/VFABIDemanglerTest.cpp
@@ -753,6 +753,79 @@ TEST_F(VFABIParserTest, ParseVoidReturnTypeSVE) {
   EXPECT_EQ(VectorName, "vector_foo");
 }

+TEST_F(VFABIParserTest, ParseWideStructReturnTypeSVE) {
+  EXPECT_TRUE(
+      invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, double}(float)"));
+  EXPECT_EQ(ISA, VFISAKind::SVE);
+  EXPECT_TRUE(isMasked());
+  FunctionType *FTy = FunctionType::get(
+      StructType::get(
+          VectorType::get(Type::getDoubleTy(Ctx), ElementCount::getScalable(2)),
+          VectorType::get(Type::getDoubleTy(Ctx),
+                          ElementCount::getScalable(2))),
+      {
+          VectorType::get(Type::getFloatTy(Ctx), ElementCount::getScalable(2)),
+          VectorType::get(Type::getInt1Ty(Ctx), ElementCount::getScalable(2)),
+      },
+      false);
+  EXPECT_EQ(getFunctionType(), FTy);
+  EXPECT_EQ(Parameters.size(), 2U);
+  EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
+  EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
+  EXPECT_EQ(VF, ElementCount::getScalable(2));
+  EXPECT_EQ(ScalarName, "foo");
+  EXPECT_EQ(VectorName, "vector_foo");
+}
+
+TEST_F(VFABIParserTest, ParseWideStructMixedReturnTypeSVE) {
+  EXPECT_TRUE(invokeParser("_ZGVsMxv_foo(vector_foo)", "{float, i64}(float)"));
+  EXPECT_EQ(ISA, VFISAKind::SVE);
+  EXPECT_TRUE(isMasked());
+  FunctionType *FTy = FunctionType::get(
+      StructType::get(
+          VectorType::get(Type::getFloatTy(Ctx), ElementCount::getScalable(2)),
+          VectorType::get(Type::getInt64Ty(Ctx), ElementCount::getScalable(2))),
+      {
+          VectorType::get(Type::getFloatTy(Ctx), ElementCount::getScalable(2)),
+          VectorType::get(Type::getInt1Ty(Ctx), ElementCount::getScalable(2)),
+      },
+      false);
+  EXPECT_EQ(getFunctionType(), FTy);
+  EXPECT_EQ(Parameters.size(), 2U);
+  EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
+  EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
+  EXPECT_EQ(VF, ElementCount::getScalable(2));
+  EXPECT_EQ(ScalarName, "foo");
+  EXPECT_EQ(VectorName, "vector_foo");
+}
+
+TEST_F(VFABIParserTest, ParseWideStructReturnTypeNEON) {
+  EXPECT_TRUE(
+      invokeParser("_ZGVnN2v_foo(vector_foo)", "{float, float}(float)"));
+  EXPECT_EQ(ISA, VFISAKind::AdvancedSIMD);
+  EXPECT_FALSE(isMasked());
+  FunctionType *FTy = FunctionType::get(
+      StructType::get(
+          VectorType::get(Type::getFloatTy(Ctx), ElementCount::getFixed(2)),
+          VectorType::get(Type::getFloatTy(Ctx), ElementCount::getFixed(2))),
+      {
+          VectorType::get(Type::getFloatTy(Ctx), ElementCount::getFixed(2)),
+      },
+      false);
+  EXPECT_EQ(getFunctionType(), FTy);
+  EXPECT_EQ(Parameters.size(), 1U);
+  EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
+  EXPECT_EQ(VF, ElementCount::getFixed(2));
+  EXPECT_EQ(ScalarName, "foo");
+  EXPECT_EQ(VectorName, "vector_foo");
+}
+
+TEST_F(VFABIParserTest, ParseUnsupportedStructReturnElementTypeSVE) {
+  // Every member of a struct return type must be a supported element type.
+  EXPECT_FALSE(
+      invokeParser("_ZGVsMxv_foo(vector_foo)", "{float, i128}(float)"));
+}
+
 // Make sure we reject unsupported parameter types.
 TEST_F(VFABIParserTest, ParseUnsupportedElementTypeSVE) {
   EXPECT_FALSE(invokeParser("_ZGVsMxv_foo(vector_foo)", "void(i128)"));

>From 349b678f7553f0b25e9dd85cde7230da2db2e75c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 12 Dec 2024 17:18:37 +0000
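Subject: [PATCH 2/2] Fixups

Rework the helpers added in the previous patch:

- Replace `CallWideningUtils` with `StructWideningUtils`, which only
  provides the struct-specific helpers (`ToWideStructTy`,
  `ToNarrowStructTy` and `isWideStructTy`).
- Define `ToWideTy`, `ToNarrowTy` and `isWideTy` inline in
  `llvm/IR/VectorUtils.h` (dispatching struct types to the new struct
  helpers), along with `getContainedTypes` and `getWideTypeVF`.
- Return an `ArrayRef` rather than a `SmallVector` from
  `getContainedTypes`.
- Rename `CallWideningUtilsTest.cpp` to `VectorUtilsTest.cpp`.
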
---
 llvm/include/llvm/IR/CallWideningUtils.h      | 44 -----------
 llvm/include/llvm/IR/StructWideningUtils.h    | 32 ++++++++
 llvm/include/llvm/IR/VectorUtils.h            | 46 +++++++++++-
 llvm/lib/IR/CMakeLists.txt                    |  2 +-
 llvm/lib/IR/CallWideningUtils.cpp             | 73 -------------------
 llvm/lib/IR/StructWideningUtils.cpp           | 57 +++++++++++++++
 llvm/lib/IR/VFABIDemangler.cpp                |  2 +-
 llvm/unittests/IR/CMakeLists.txt              |  2 +-
 ...eningUtilsTest.cpp => VectorUtilsTest.cpp} | 26 +++----
 9 files changed, 149 insertions(+), 135 deletions(-)
 delete mode 100644 llvm/include/llvm/IR/CallWideningUtils.h
 create mode 100644 llvm/include/llvm/IR/StructWideningUtils.h
 delete mode 100644 llvm/lib/IR/CallWideningUtils.cpp
 create mode 100644 llvm/lib/IR/StructWideningUtils.cpp
 rename llvm/unittests/IR/{CallWideningUtilsTest.cpp => VectorUtilsTest.cpp} (86%)

diff --git a/llvm/include/llvm/IR/CallWideningUtils.h b/llvm/include/llvm/IR/CallWideningUtils.h
deleted file mode 100644
index de51c8f6c6ba1f..00000000000000
--- a/llvm/include/llvm/IR/CallWideningUtils.h
+++ /dev/null
@@ -1,44 +0,0 @@
-//===---- CallWideningUtils.h - Utils for widening scalar to vector calls --==//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_IR_CALLWIDENINGUTILS_H
-#define LLVM_IR_CALLWIDENINGUTILS_H
-
-#include "llvm/IR/DerivedTypes.h"
-
-namespace llvm {
-
-/// A helper for converting to wider (vector) types. For scalar types, this is
-/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
-/// struct where each element type has been widened to a vector type. Note: Only
-/// unpacked literal struct types are supported.
-Type *ToWideTy(Type *Ty, ElementCount EC);
-
-/// A helper for converting wide types to narrow (non-vector) types. For vector
-/// types, this is equivalent to calling .getScalarType(). For struct types,
-/// this returns a new struct where each element type has been converted to a
-/// scalar type. Note: Only unpacked literal struct types are supported.
-Type *ToNarrowTy(Type *Ty);
-
-/// Returns the types contained in `Ty`. For struct types, it returns the
-/// elements, all other types are returned directly.
-SmallVector<Type *, 2> getContainedTypes(Type *Ty);
-
-/// Returns true if `Ty` is a vector type or a struct of vector types where all
-/// vector types share the same VF.
-bool isWideTy(Type *Ty);
-
-/// Returns the vectorization factor for a widened type.
-inline ElementCount getWideTypeVF(Type *Ty) {
-  assert(isWideTy(Ty) && "expected widened type");
-  return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
-}
-
-} // namespace llvm
-
-#endif
diff --git a/llvm/include/llvm/IR/StructWideningUtils.h b/llvm/include/llvm/IR/StructWideningUtils.h
new file mode 100644
index 00000000000000..b1b922f2b5d8be
--- /dev/null
+++ b/llvm/include/llvm/IR/StructWideningUtils.h
@@ -0,0 +1,34 @@
+//===-- StructWideningUtils.h - Utils for widening struct types -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_STRUCTWIDENINGUTILS_H
+#define LLVM_IR_STRUCTWIDENINGUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/DerivedTypes.h"
+
+namespace llvm {
+
+/// A helper for converting structs of scalar types to structs of vector types.
+/// Each element type is replaced by a vector of that type with element count
+/// `EC`. If `EC` is scalar, `StructTy` is returned unchanged.
+/// Note: Only unpacked literal struct types are supported.
+Type *ToWideStructTy(StructType *StructTy, ElementCount EC);
+
+/// A helper for converting structs of vector types to structs of scalar types.
+/// Each element type `T` is replaced by `T->getScalarType()`.
+/// Note: Only unpacked literal struct types are supported.
+Type *ToNarrowStructTy(StructType *StructTy);
+
+/// Returns true if `StructTy` is an unpacked literal struct where all elements
+/// are vectors of matching element count. This does not include empty structs.
+bool isWideStructTy(StructType *StructTy);
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/include/llvm/IR/VectorUtils.h b/llvm/include/llvm/IR/VectorUtils.h
index a2e34a02e08a67..89c8dc2911c50d 100644
--- a/llvm/include/llvm/IR/VectorUtils.h
+++ b/llvm/include/llvm/IR/VectorUtils.h
@@ -1,4 +1,4 @@
-//===----------- VectorUtils.h -  Vector type utility functions -*- C++ -*-===//
+//===----------- VectorUtils.h - Vector type utility functions -*- C++ -*-====//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,8 +9,8 @@
 #ifndef LLVM_IR_VECTORUTILS_H
 #define LLVM_IR_VECTORUTILS_H
 
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/StructWideningUtils.h"
 
 namespace llvm {
 
@@ -27,6 +27,56 @@ inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
   return ToVectorTy(Scalar, ElementCount::getFixed(VF));
 }

+/// A helper for converting to wider (vector) types. For scalar types, this is
+/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
+/// struct where each element type has been widened to a vector type. Note: Only
+/// unpacked literal struct types are supported.
+inline Type *ToWideTy(Type *Ty, ElementCount EC) {
+  if (StructType *StructTy = dyn_cast<StructType>(Ty))
+    return ToWideStructTy(StructTy, EC);
+  return ToVectorTy(Ty, EC);
+}
+
+/// A helper for converting wide types to narrow (non-vector) types. For vector
+/// types, this is equivalent to calling .getScalarType(). For struct types,
+/// this returns a new struct where each element type has been converted to a
+/// scalar type. Note: Only unpacked literal struct types are supported.
+inline Type *ToNarrowTy(Type *Ty) {
+  if (StructType *StructTy = dyn_cast<StructType>(Ty))
+    return ToNarrowStructTy(StructTy);
+  return Ty->getScalarType();
+}
+
+/// Returns true if `Ty` is a vector type or a struct of vector types where all
+/// vector types share the same VF. Only non-empty, unpacked literal structs
+/// are considered (see `isWideStructTy`).
+inline bool isWideTy(Type *Ty) {
+  if (StructType *StructTy = dyn_cast<StructType>(Ty))
+    return isWideStructTy(StructTy);
+  return Ty->isVectorTy();
+}
+
+/// Returns the types contained in `Ty`. For struct types, it returns the
+/// elements, all other types are returned directly.
+///
+/// Note: For non-struct types the returned ArrayRef refers to the `Ty`
+/// argument itself. If `Ty` is a temporary (e.g. a call result, or a
+/// `VectorType *` implicitly converted to `Type *`), the result is only valid
+/// within the same full-expression; storing it, or iterating over it in a
+/// range-based for loop, leaves it dangling.
+inline ArrayRef<Type *> getContainedTypes(Type *const &Ty) {
+  if (auto *StructTy = dyn_cast<StructType>(Ty))
+    return StructTy->elements();
+  return ArrayRef<Type *>{&Ty, 1};
+}
+
+/// Returns the vectorization factor for a widened type. `Ty` must satisfy
+/// `isWideTy`; this is asserted.
+inline ElementCount getWideTypeVF(Type *Ty) {
+  assert(isWideTy(Ty) && "expected widened type");
+  return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
+}
+
 } // namespace llvm

 #endif
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 874ee8c4795a96..42aceadde59227 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -6,7 +6,6 @@ add_llvm_component_library(LLVMCore
   AutoUpgrade.cpp
   BasicBlock.cpp
   BuiltinGCs.cpp
-  CallWideningUtils.cpp
   Comdat.cpp
   ConstantFold.cpp
   ConstantFPRange.cpp
@@ -65,6 +64,7 @@ add_llvm_component_library(LLVMCore
   PseudoProbe.cpp
   ReplaceConstant.cpp
   Statepoint.cpp
+  StructWideningUtils.cpp
   StructuralHash.cpp
   Type.cpp
   TypedPointerType.cpp
diff --git a/llvm/lib/IR/CallWideningUtils.cpp b/llvm/lib/IR/CallWideningUtils.cpp
deleted file mode 100644
index ec0bc6e3baa463..00000000000000
--- a/llvm/lib/IR/CallWideningUtils.cpp
+++ /dev/null
@@ -1,73 +0,0 @@
-//===----------- VectorUtils.cpp - Vector type utility functions ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "llvm/IR/CallWideningUtils.h"
-#include "llvm/ADT/SmallVectorExtras.h"
-#include "llvm/IR/VectorUtils.h"
-
-using namespace llvm;
-
-/// A helper for converting to wider (vector) types. For scalar types, this is
-/// equivalent to calling `ToVectorTy`. For struct types, this returns a new
-/// struct where each element type has been widened to a vector type. Note: Only
-/// unpacked literal struct types are supported.
-Type *llvm::ToWideTy(Type *Ty, ElementCount EC) {
-  if (EC.isScalar())
-    return Ty;
-  auto *StructTy = dyn_cast<StructType>(Ty);
-  if (!StructTy)
-    return ToVectorTy(Ty, EC);
-  assert(StructTy->isLiteral() && !StructTy->isPacked() &&
-         "expected unpacked struct literal");
-  return StructType::get(
-      Ty->getContext(),
-      map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
-        return VectorType::get(ElTy, EC);
-      }));
-}
-
-/// A helper for converting wide types to narrow (non-vector) types. For vector
-/// types, this is equivalent to calling .getScalarType(). For struct types,
-/// this returns a new struct where each element type has been converted to a
-/// scalar type. Note: Only unpacked literal struct types are supported.
-Type *llvm::ToNarrowTy(Type *Ty) {
-  auto *StructTy = dyn_cast<StructType>(Ty);
-  if (!StructTy)
-    return Ty->getScalarType();
-  assert(StructTy->isLiteral() && !StructTy->isPacked() &&
-         "expected unpacked struct literal");
-  return StructType::get(
-      Ty->getContext(),
-      map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
-        return ElTy->getScalarType();
-      }));
-}
-
-/// Returns the types contained in `Ty`. For struct types, it returns the
-/// elements, all other types are returned directly.
-SmallVector<Type *, 2> llvm::getContainedTypes(Type *Ty) {
-  auto *StructTy = dyn_cast<StructType>(Ty);
-  if (StructTy)
-    return to_vector<2>(StructTy->elements());
-  return {Ty};
-}
-
-/// Returns true if `Ty` is a vector type or a struct of vector types where all
-/// vector types share the same VF.
-bool llvm::isWideTy(Type *Ty) {
-  auto *StructTy = dyn_cast<StructType>(Ty);
-  if (StructTy && (!StructTy->isLiteral() || StructTy->isPacked()))
-    return false;
-  auto ContainedTys = getContainedTypes(Ty);
-  if (ContainedTys.empty() || !ContainedTys.front()->isVectorTy())
-    return false;
-  ElementCount VF = cast<VectorType>(ContainedTys.front())->getElementCount();
-  return all_of(ContainedTys, [&](Type *Ty) {
-    return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
-  });
-}
diff --git a/llvm/lib/IR/StructWideningUtils.cpp b/llvm/lib/IR/StructWideningUtils.cpp
new file mode 100644
index 00000000000000..5c325c7cbdafc6
--- /dev/null
+++ b/llvm/lib/IR/StructWideningUtils.cpp
@@ -0,0 +1,57 @@
+//===-- StructWideningUtils.cpp - Utils for widening struct types ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/StructWideningUtils.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/IR/VectorUtils.h"
+
+using namespace llvm;
+
+static bool isUnpackedStructLiteral(StructType *StructTy) {
+  return !StructTy->isPacked() && StructTy->isLiteral();
+}
+
+/// Widens each element of the unpacked literal struct `StructTy` to a vector
+/// of `EC` elements. For a scalar `EC`, `StructTy` is returned unchanged.
+Type *llvm::ToWideStructTy(StructType *StructTy, ElementCount EC) {
+  if (EC.isScalar())
+    return StructTy;
+  assert(isUnpackedStructLiteral(StructTy) &&
+         "expected unpacked struct literal");
+  auto WidenElt = [EC](Type *EltTy) -> Type * {
+    return VectorType::get(EltTy, EC);
+  };
+  auto WideElts = map_to_vector(StructTy->elements(), WidenElt);
+  return StructType::get(StructTy->getContext(), WideElts);
+}
+
+/// Maps each element of the unpacked literal struct `StructTy` to its scalar
+/// type (via Type::getScalarType()), e.g. {<4 x float>} becomes {float}.
+Type *llvm::ToNarrowStructTy(StructType *StructTy) {
+  assert(isUnpackedStructLiteral(StructTy) &&
+         "expected unpacked struct literal");
+  auto NarrowElt = [](Type *EltTy) -> Type * {
+    return EltTy->getScalarType();
+  };
+  auto NarrowElts = map_to_vector(StructTy->elements(), NarrowElt);
+  return StructType::get(StructTy->getContext(), NarrowElts);
+}
+
+/// Checks whether `StructTy` is a non-empty unpacked literal struct whose
+/// elements are all vectors sharing one element count (fixed or scalable).
+bool llvm::isWideStructTy(StructType *StructTy) {
+  if (!isUnpackedStructLiteral(StructTy) || StructTy->getNumElements() == 0)
+    return false;
+  auto *FirstVecTy = dyn_cast<VectorType>(StructTy->getElementType(0));
+  if (!FirstVecTy)
+    return false;
+  return all_of(StructTy->elements(), [FirstVecTy](Type *EltTy) {
+    auto *VecTy = dyn_cast<VectorType>(EltTy);
+    return VecTy && VecTy->getElementCount() == FirstVecTy->getElementCount();
+  });
+}
diff --git a/llvm/lib/IR/VFABIDemangler.cpp b/llvm/lib/IR/VFABIDemangler.cpp
index 19c922c8bf035b..b7559e2f4108a7 100644
--- a/llvm/lib/IR/VFABIDemangler.cpp
+++ b/llvm/lib/IR/VFABIDemangler.cpp
@@ -10,8 +10,8 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringSwitch.h"
-#include "llvm/IR/CallWideningUtils.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/VectorUtils.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <limits>
diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
index c4cc1bd78403a8..efbe710987a489 100644
--- a/llvm/unittests/IR/CMakeLists.txt
+++ b/llvm/unittests/IR/CMakeLists.txt
@@ -15,7 +15,6 @@ add_llvm_unittest(IRTests
   AttributesTest.cpp
   BasicBlockTest.cpp
   BasicBlockDbgInfoTest.cpp
-  CallWideningUtilsTest.cpp
   CFGBuilder.cpp
   ConstantFPRangeTest.cpp
   ConstantRangeTest.cpp
@@ -53,6 +52,7 @@ add_llvm_unittest(IRTests
   ValueTest.cpp
   VectorBuilderTest.cpp
   VectorTypesTest.cpp
+  VectorUtilsTest.cpp
   VerifierTest.cpp
   VFABIDemanglerTest.cpp
   VPIntrinsicTest.cpp
diff --git a/llvm/unittests/IR/CallWideningUtilsTest.cpp b/llvm/unittests/IR/VectorUtilsTest.cpp
similarity index 86%
rename from llvm/unittests/IR/CallWideningUtilsTest.cpp
rename to llvm/unittests/IR/VectorUtilsTest.cpp
index 939806212eee6a..b70cd2f511df48 100644
--- a/llvm/unittests/IR/CallWideningUtilsTest.cpp
+++ b/llvm/unittests/IR/VectorUtilsTest.cpp
@@ -1,4 +1,4 @@
-//===------- CallWideningUtilsTest.cpp - Call widening utils tests --------===//
+//===------- VectorUtilsTest.cpp - Vector utils tests ---------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/IR/CallWideningUtils.h"
+#include "llvm/IR/VectorUtils.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "gtest/gtest.h"
@@ -15,9 +15,9 @@ using namespace llvm;
 
 namespace {
 
-class CallWideningUtilsTest : public ::testing::Test {};
+class VectorUtilsTest : public ::testing::Test {};
 
-TEST(CallWideningUtilsTest, TestToWideTy) {
+TEST(VectorUtilsTest, TestToWideTy) {
   LLVMContext C;
 
   Type *ITy = Type::getInt32Ty(C);
@@ -62,7 +62,7 @@ TEST(CallWideningUtilsTest, TestToWideTy) {
   }
 }
 
-TEST(CallWideningUtilsTest, TestToNarrowTy) {
+TEST(VectorUtilsTest, TestToNarrowTy) {
   LLVMContext C;
 
   Type *ITy = Type::getInt32Ty(C);
@@ -80,7 +80,7 @@ TEST(CallWideningUtilsTest, TestToNarrowTy) {
   }
 }
 
-TEST(CallWideningUtilsTest, TestGetContainedTypes) {
+TEST(VectorUtilsTest, TestGetContainedTypes) {
   LLVMContext C;
 
   Type *ITy = Type::getInt32Ty(C);
@@ -89,15 +89,15 @@ TEST(CallWideningUtilsTest, TestGetContainedTypes) {
   Type *MixedStructTy = StructType::get(FTy, ITy);
   Type *VoidTy = Type::getVoidTy(C);
 
-  EXPECT_EQ(getContainedTypes(ITy), SmallVector<Type *>({ITy}));
-  EXPECT_EQ(getContainedTypes(FTy), SmallVector<Type *>({FTy}));
-  EXPECT_EQ(getContainedTypes(VoidTy), SmallVector<Type *>({VoidTy}));
+  EXPECT_EQ(getContainedTypes(ITy), ArrayRef<Type *>({ITy}));
+  EXPECT_EQ(getContainedTypes(FTy), ArrayRef<Type *>({FTy}));
+  EXPECT_EQ(getContainedTypes(VoidTy), ArrayRef<Type *>({VoidTy}));
   EXPECT_EQ(getContainedTypes(HomogeneousStructTy),
-            SmallVector<Type *>({FTy, FTy, FTy}));
-  EXPECT_EQ(getContainedTypes(MixedStructTy), SmallVector<Type *>({FTy, ITy}));
+            ArrayRef<Type *>({FTy, FTy, FTy}));
+  EXPECT_EQ(getContainedTypes(MixedStructTy), ArrayRef<Type *>({FTy, ITy}));
 }
 
-TEST(CallWideningUtilsTest, TestIsWideTy) {
+TEST(VectorUtilsTest, TestIsWideTy) {
   LLVMContext C;
 
   Type *ITy = Type::getInt32Ty(C);
@@ -130,7 +130,7 @@ TEST(CallWideningUtilsTest, TestIsWideTy) {
   EXPECT_FALSE(isWideTy(PackedWideStruct));
 }
 
-TEST(CallWideningUtilsTest, TestGetWideTypeVF) {
+TEST(VectorUtilsTest, TestGetWideTypeVF) {
   LLVMContext C;
 
   Type *ITy = Type::getInt32Ty(C);



More information about the llvm-commits mailing list