[llvm] c954986 - [GISel] Add support for scalable vectors in getGCDType (#80307)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 7 07:32:16 PST 2024


Author: Michael Maitland
Date: 2024-02-07T10:32:12-05:00
New Revision: c954986fec97ab22a9658b496731d0c280938a64

URL: https://github.com/llvm/llvm-project/commit/c954986fec97ab22a9658b496731d0c280938a64
DIFF: https://github.com/llvm/llvm-project/commit/c954986fec97ab22a9658b496731d0c280938a64.diff

LOG: [GISel] Add support for scalable vectors in getGCDType (#80307)

getGCDType can be called from buildCopyToRegs with at least one scalable
vector type. Previously it crashed in that case because it did not know
how to handle scalable vector types.

This patch extends getGCDType to handle the case where at least one of
the types is a scalable vector. getGCDType between a fixed and a scalable
vector is not implemented: as the function's docstring explains,
getGCDType is used to build MERGE/UNMERGE instructions, and we will never
build a MERGE/UNMERGE between fixed and scalable vectors.

---------

Co-authored-by: Matt Arsenault <arsenm2 at gmail.com>

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/Utils.h
    llvm/lib/CodeGen/GlobalISel/Utils.cpp
    llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index c96e4217d21f0..f8900f3434cca 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -368,7 +368,10 @@ LLT getCoverTy(LLT OrigTy, LLT TargetTy);
 /// If these are vectors with different element types, this will try to produce
 /// a vector with a compatible total size, but the element type of \p OrigTy. If
 /// this can't be satisfied, this will produce a scalar smaller than the
-/// original vector elements.
+/// original vector elements. It is an error to call this function where
+/// one argument is a fixed vector and the other is a scalable vector, since it
+/// is illegal to build a G_{MERGE|UNMERGE}_VALUES between fixed and scalable
+/// vectors.
 ///
 /// In the worst case, this returns LLT::scalar(1)
 LLVM_READNONE

diff  --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index dd99381093b6a..26fd12f9e51c4 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1159,45 +1159,56 @@ LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {
 }
 
 LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
-  const unsigned OrigSize = OrigTy.getSizeInBits();
-  const unsigned TargetSize = TargetTy.getSizeInBits();
-
-  if (OrigSize == TargetSize)
+  if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
     return OrigTy;
 
-  if (OrigTy.isVector()) {
+  if (OrigTy.isVector() && TargetTy.isVector()) {
     LLT OrigElt = OrigTy.getElementType();
-    if (TargetTy.isVector()) {
-      LLT TargetElt = TargetTy.getElementType();
-      if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
-        int GCD = std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
-        return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
-      }
-    } else {
-      // If the source is a vector of pointers, return a pointer element.
-      if (OrigElt.getSizeInBits() == TargetSize)
-        return OrigElt;
-    }
 
-    unsigned GCD = std::gcd(OrigSize, TargetSize);
+    // TODO: The docstring for this function says the intention is to use this
+    // function to build MERGE/UNMERGE instructions. It won't be the case that
+    // we generate a MERGE/UNMERGE between fixed and scalable vector types. We
+    // could implement getGCDType between the two in the future if there was a
+    // need, but it is not worth it now as this function should not be used in
+    // that way.
+    assert(((OrigTy.isScalableVector() && !TargetTy.isFixedVector()) ||
+            (OrigTy.isFixedVector() && !TargetTy.isScalableVector())) &&
+           "getGCDType not implemented between fixed and scalable vectors.");
+
+    unsigned GCD = std::gcd(OrigTy.getSizeInBits().getKnownMinValue(),
+                            TargetTy.getSizeInBits().getKnownMinValue());
     if (GCD == OrigElt.getSizeInBits())
-      return OrigElt;
+      return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                 OrigElt);
 
-    // If we can't produce the original element type, we have to use a smaller
-    // scalar.
+    // Cannot produce original element type, but both have vscale in common.
     if (GCD < OrigElt.getSizeInBits())
-      return LLT::scalar(GCD);
-    return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt);
-  }
+      return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                 GCD);
 
-  if (TargetTy.isVector()) {
-    // Try to preserve the original element type.
-    LLT TargetElt = TargetTy.getElementType();
-    if (TargetElt.getSizeInBits() == OrigSize)
-      return OrigTy;
+    return LLT::vector(
+        ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
+                          OrigTy.isScalable()),
+        OrigElt);
   }
 
-  unsigned GCD = std::gcd(OrigSize, TargetSize);
+  // If one type is vector and the element size matches the scalar size, then
+  // the gcd is the scalar type.
+  if (OrigTy.isVector() &&
+      OrigTy.getElementType().getSizeInBits() == TargetTy.getSizeInBits())
+    return OrigTy.getElementType();
+  if (TargetTy.isVector() &&
+      TargetTy.getElementType().getSizeInBits() == OrigTy.getSizeInBits())
+    return OrigTy;
+
+  // At this point, both types are either scalars of different type or one is a
+  // vector and one is a scalar. If both types are scalars, the GCD type is the
+  // GCD between the two scalar sizes. If one is vector and one is scalar, then
+  // the GCD type is the GCD between the scalar and the vector element size.
+  LLT OrigScalar = OrigTy.getScalarType();
+  LLT TargetScalar = TargetTy.getScalarType();
+  unsigned GCD = std::gcd(OrigScalar.getSizeInBits().getFixedValue(),
+                          TargetScalar.getSizeInBits().getFixedValue());
   return LLT::scalar(GCD);
 }
 

diff  --git a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
index 92bd0a36b82b4..1ff7fd956d015 100644
--- a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
@@ -183,6 +183,62 @@ TEST(GISelUtilsTest, getGCDType) {
 
   EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::fixed_vector(3, 4), S8));
   EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::fixed_vector(3, 4)));
+
+  // Scalable -> Scalable
+  EXPECT_EQ(NXV1S1, getGCDType(NXV1S1, NXV1S32));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV1S64, NXV1S32));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV1S32, NXV1S64));
+  EXPECT_EQ(NXV1P0, getGCDType(NXV1P0, NXV1S64));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV1S64, NXV1P0));
+
+  EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV4S32));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV4S32));
+  EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV4S64));
+  EXPECT_EQ(NXV4P0, getGCDType(NXV4P0, NXV4S64));
+  EXPECT_EQ(NXV4S64, getGCDType(NXV4S64, NXV4P0));
+
+  EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV2S32));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV4S64, NXV2S32));
+  EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV2S64));
+  EXPECT_EQ(NXV2P0, getGCDType(NXV4P0, NXV2S64));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV2P0));
+
+  EXPECT_EQ(NXV2S1, getGCDType(NXV2S1, NXV4S32));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4S32));
+  EXPECT_EQ(NXV2S32, getGCDType(NXV2S32, NXV4S64));
+  EXPECT_EQ(NXV2P0, getGCDType(NXV2P0, NXV4S64));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4P0));
+
+  EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S32));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S32));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S64));
+  EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4S64));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4P0));
+
+  EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S1));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S32));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S64));
+  EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4P0));
+
+  // Scalable, Scalar
+
+  EXPECT_EQ(S1, getGCDType(NXV1S1, S1));
+  EXPECT_EQ(S1, getGCDType(NXV1S1, S32));
+  EXPECT_EQ(S1, getGCDType(NXV1S32, S1));
+  EXPECT_EQ(S32, getGCDType(NXV1S32, S32));
+  EXPECT_EQ(S32, getGCDType(NXV1S32, S64));
+  EXPECT_EQ(S1, getGCDType(NXV2S32, S1));
+  EXPECT_EQ(S32, getGCDType(NXV2S32, S32));
+  EXPECT_EQ(S32, getGCDType(NXV2S32, S64));
+
+  EXPECT_EQ(S1, getGCDType(S1, NXV1S1));
+  EXPECT_EQ(S1, getGCDType(S32, NXV1S1));
+  EXPECT_EQ(S1, getGCDType(S1, NXV1S32));
+  EXPECT_EQ(S32, getGCDType(S32, NXV1S32));
+  EXPECT_EQ(S32, getGCDType(S64, NXV1S32));
+  EXPECT_EQ(S1, getGCDType(S1, NXV2S32));
+  EXPECT_EQ(S32, getGCDType(S32, NXV2S32));
+  EXPECT_EQ(S32, getGCDType(S64, NXV2S32));
 }
 
 TEST(GISelUtilsTest, getLCMType) {


        


More information about the llvm-commits mailing list