[llvm] [GISel] Add support for scalable vectors in getGCDType (PR #80307)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 07:27:29 PST 2024


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/80307

>From ce16521c567fbf67fcb6f780ed3cccdee7783513 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Wed, 31 Jan 2024 11:51:26 -0800
Subject: [PATCH 1/5] [GISel] Add support for scalable vectors in getGCDType

getGCDType can be called from buildCopyToRegs where at least one of the
types is a scalable vector type. Before this patch, the function crashed in
that case because it did not know how to handle scalable vector types.

This patch extends getGCDType to handle the case where at least one of the
types is a scalable vector. getGCDType between a fixed vector and a scalable
vector is not implemented: the docstring of the function explains that
getGCDType is used to build MERGE/UNMERGE instructions, and we will never
build a MERGE/UNMERGE between fixed and scalable vectors.
---
 llvm/include/llvm/CodeGen/GlobalISel/Utils.h  |  5 +-
 llvm/lib/CodeGen/GlobalISel/Utils.cpp         | 79 +++++++++++--------
 .../CodeGen/GlobalISel/GISelUtilsTest.cpp     | 76 ++++++++++++++++++
 3 files changed, 127 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index bf02911e19351..a4fcc28d5af39 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -365,7 +365,10 @@ LLT getCoverTy(LLT OrigTy, LLT TargetTy);
 /// If these are vectors with different element types, this will try to produce
 /// a vector with a compatible total size, but the element type of \p OrigTy. If
 /// this can't be satisfied, this will produce a scalar smaller than the
-/// original vector elements.
+/// original vector elements. It is an error to call this function where
+/// one argument is a fixed vector and the other is a scalable vector, since it
+/// is illegal to build a G_{MERGE|UNMERGE}_VALUES between fixed and scalable
+/// vectors.
 ///
 /// In the worst case, this returns LLT::scalar(1)
 LLVM_READNONE
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index aed826a9cbc54..0f2def47ec1e7 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1132,45 +1132,60 @@ LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {
 }
 
 LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
-  const unsigned OrigSize = OrigTy.getSizeInBits();
-  const unsigned TargetSize = TargetTy.getSizeInBits();
-
-  if (OrigSize == TargetSize)
+  if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
     return OrigTy;
 
-  if (OrigTy.isVector()) {
+  if (OrigTy.isVector() && TargetTy.isVector()) {
     LLT OrigElt = OrigTy.getElementType();
-    if (TargetTy.isVector()) {
-      LLT TargetElt = TargetTy.getElementType();
-      if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
-        int GCD = std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
-        return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
-      }
-    } else {
-      // If the source is a vector of pointers, return a pointer element.
-      if (OrigElt.getSizeInBits() == TargetSize)
-        return OrigElt;
-    }
-
-    unsigned GCD = std::gcd(OrigSize, TargetSize);
-    if (GCD == OrigElt.getSizeInBits())
-      return OrigElt;
+    LLT TargetElt = TargetTy.getElementType();
 
-    // If we can't produce the original element type, we have to use a smaller
-    // scalar.
-    if (GCD < OrigElt.getSizeInBits())
-      return LLT::scalar(GCD);
-    return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt);
+    // TODO: The docstring for this function says the intention is to use this
+    // function to build MERGE/UNMERGE instructions. It won't be the case that
+    // we generate a MERGE/UNMERGE between fixed and scalable vector types. We
+    // could implement getGCDType between the two in the future if there was a
+    // need, but it is not worth it now as this function should not be used in
+    // that way.
+    if ((OrigTy.isScalableVector() && TargetTy.isFixedVector()) ||
+        (OrigTy.isFixedVector() && TargetTy.isScalableVector()))
+      llvm_unreachable(
+          "getGCDType not implemented between fixed and scalable vectors.");
+
+      unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
+                                  OrigElt.getSizeInBits().getFixedValue(),
+                              TargetTy.getElementCount().getKnownMinValue() *
+                                  TargetElt.getSizeInBits().getFixedValue());
+      if (GCD == OrigElt.getSizeInBits())
+        return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                   OrigElt);
+
+      // Cannot produce original element type, but both have vscale in common.
+      if (GCD < OrigElt.getSizeInBits())
+        return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                   GCD);
+
+      return LLT::vector(
+          ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
+                            OrigTy.isScalable()),
+          OrigElt);
   }
 
-  if (TargetTy.isVector()) {
-    // Try to preserve the original element type.
-    LLT TargetElt = TargetTy.getElementType();
-    if (TargetElt.getSizeInBits() == OrigSize)
-      return OrigTy;
-  }
+  // If one type is vector and the element size matches the scalar size, then
+  // the gcd is the scalar type.
+  if (OrigTy.isVector() &&
+      OrigTy.getElementType().getSizeInBits() == TargetTy.getSizeInBits())
+    return OrigTy.getElementType();
+  if (TargetTy.isVector() &&
+      TargetTy.getElementType().getSizeInBits() == OrigTy.getSizeInBits())
+    return OrigTy;
 
-  unsigned GCD = std::gcd(OrigSize, TargetSize);
+  // At this point, both types are either scalars of different type or one is a
+  // vector and one is a scalar. If both types are scalars, the GCD type is the
+  // GCD between the two scalar sizes. If one is vector and one is scalar, then
+  // the GCD type is the GCD between the scalar and the vector element size.
+  LLT OrigScalar = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;
+  LLT TargetScalar = TargetTy.isVector() ? TargetTy.getElementType() : TargetTy;
+  unsigned GCD = std::gcd(OrigScalar.getSizeInBits().getFixedValue(),
+                          TargetScalar.getSizeInBits().getFixedValue());
   return LLT::scalar(GCD);
 }
 
diff --git a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
index 8fda332d5c054..83cc4f2a7f388 100644
--- a/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/GISelUtilsTest.cpp
@@ -46,6 +46,26 @@ static const LLT V6P0 = LLT::fixed_vector(6, P0);
 static const LLT V2P1 = LLT::fixed_vector(2, P1);
 static const LLT V4P1 = LLT::fixed_vector(4, P1);
 
+static const LLT NXV1S1 = LLT::scalable_vector(1, S1);
+static const LLT NXV2S1 = LLT::scalable_vector(2, S1);
+static const LLT NXV3S1 = LLT::scalable_vector(3, S1);
+static const LLT NXV4S1 = LLT::scalable_vector(4, S1);
+
+static const LLT NXV1S32 = LLT::scalable_vector(1, S32);
+static const LLT NXV2S32 = LLT::scalable_vector(2, S32);
+static const LLT NXV3S32 = LLT::scalable_vector(3, S32);
+static const LLT NXV4S32 = LLT::scalable_vector(4, S32);
+
+static const LLT NXV1S64 = LLT::scalable_vector(1, S64);
+static const LLT NXV2S64 = LLT::scalable_vector(2, S64);
+static const LLT NXV3S64 = LLT::scalable_vector(3, S64);
+static const LLT NXV4S64 = LLT::scalable_vector(4, S64);
+
+static const LLT NXV1P0 = LLT::scalable_vector(1, P0);
+static const LLT NXV2P0 = LLT::scalable_vector(2, P0);
+static const LLT NXV3P0 = LLT::scalable_vector(3, P0);
+static const LLT NXV4P0 = LLT::scalable_vector(4, P0);
+
 TEST(GISelUtilsTest, getGCDType) {
   EXPECT_EQ(S1, getGCDType(S1, S1));
   EXPECT_EQ(S32, getGCDType(S32, S32));
@@ -152,6 +172,62 @@ TEST(GISelUtilsTest, getGCDType) {
 
   EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::fixed_vector(3, 4), S8));
   EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::fixed_vector(3, 4)));
+
+  // Scalable -> Scalable
+  EXPECT_EQ(NXV1S1, getGCDType(NXV1S1, NXV1S32));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV1S64, NXV1S32));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV1S32, NXV1S64));
+  EXPECT_EQ(NXV1P0, getGCDType(NXV1P0, NXV1S64));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV1S64, NXV1P0));
+
+  EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV4S32));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV4S32));
+  EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV4S64));
+  EXPECT_EQ(NXV4P0, getGCDType(NXV4P0, NXV4S64));
+  EXPECT_EQ(NXV4S64, getGCDType(NXV4S64, NXV4P0));
+
+  EXPECT_EQ(NXV4S1, getGCDType(NXV4S1, NXV2S32));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV4S64, NXV2S32));
+  EXPECT_EQ(NXV4S32, getGCDType(NXV4S32, NXV2S64));
+  EXPECT_EQ(NXV2P0, getGCDType(NXV4P0, NXV2S64));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV4S64, NXV2P0));
+
+  EXPECT_EQ(NXV2S1, getGCDType(NXV2S1, NXV4S32));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4S32));
+  EXPECT_EQ(NXV2S32, getGCDType(NXV2S32, NXV4S64));
+  EXPECT_EQ(NXV2P0, getGCDType(NXV2P0, NXV4S64));
+  EXPECT_EQ(NXV2S64, getGCDType(NXV2S64, NXV4P0));
+
+  EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S32));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S32));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S64));
+  EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4S64));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4P0));
+
+  EXPECT_EQ(NXV1S1, getGCDType(NXV3S1, NXV4S1));
+  EXPECT_EQ(NXV1S32, getGCDType(NXV3S32, NXV4S32));
+  EXPECT_EQ(NXV1S64, getGCDType(NXV3S64, NXV4S64));
+  EXPECT_EQ(NXV1P0, getGCDType(NXV3P0, NXV4P0));
+
+  // Scalable, Scalar
+
+  EXPECT_EQ(S1, getGCDType(NXV1S1, S1));
+  EXPECT_EQ(S1, getGCDType(NXV1S1, S32));
+  EXPECT_EQ(S1, getGCDType(NXV1S32, S1));
+  EXPECT_EQ(S32, getGCDType(NXV1S32, S32));
+  EXPECT_EQ(S32, getGCDType(NXV1S32, S64));
+  EXPECT_EQ(S1, getGCDType(NXV2S32, S1));
+  EXPECT_EQ(S32, getGCDType(NXV2S32, S32));
+  EXPECT_EQ(S32, getGCDType(NXV2S32, S64));
+
+  EXPECT_EQ(S1, getGCDType(S1, NXV1S1));
+  EXPECT_EQ(S1, getGCDType(S32, NXV1S1));
+  EXPECT_EQ(S1, getGCDType(S1, NXV1S32));
+  EXPECT_EQ(S32, getGCDType(S32, NXV1S32));
+  EXPECT_EQ(S32, getGCDType(S64, NXV1S32));
+  EXPECT_EQ(S1, getGCDType(S1, NXV2S32));
+  EXPECT_EQ(S32, getGCDType(S32, NXV2S32));
+  EXPECT_EQ(S32, getGCDType(S64, NXV2S32));
 }
 
 TEST(GISelUtilsTest, getLCMType) {

>From 6d891947e65e71b235bd2e50decc523e2d43ac69 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Thu, 1 Feb 2024 08:36:18 -0800
Subject: [PATCH 2/5] fixup! clang-format

---
 llvm/lib/CodeGen/GlobalISel/Utils.cpp | 34 +++++++++++++--------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 0f2def47ec1e7..8fa3e47c74ac1 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1150,23 +1150,23 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
       llvm_unreachable(
           "getGCDType not implemented between fixed and scalable vectors.");
 
-      unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
-                                  OrigElt.getSizeInBits().getFixedValue(),
-                              TargetTy.getElementCount().getKnownMinValue() *
-                                  TargetElt.getSizeInBits().getFixedValue());
-      if (GCD == OrigElt.getSizeInBits())
-        return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
-                                   OrigElt);
-
-      // Cannot produce original element type, but both have vscale in common.
-      if (GCD < OrigElt.getSizeInBits())
-        return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
-                                   GCD);
-
-      return LLT::vector(
-          ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
-                            OrigTy.isScalable()),
-          OrigElt);
+    unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
+                                OrigElt.getSizeInBits().getFixedValue(),
+                            TargetTy.getElementCount().getKnownMinValue() *
+                                TargetElt.getSizeInBits().getFixedValue());
+    if (GCD == OrigElt.getSizeInBits())
+      return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                 OrigElt);
+
+    // Cannot produce original element type, but both have vscale in common.
+    if (GCD < OrigElt.getSizeInBits())
+      return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                 GCD);
+
+    return LLT::vector(
+        ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
+                          OrigTy.isScalable()),
+        OrigElt);
   }
 
   // If one type is vector and the element size matches the scalar size, then

>From b5f6a56382472038cb4c0b5557a27f1620e5bf97 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Mon, 5 Feb 2024 10:42:28 -0500
Subject: [PATCH 3/5] fixup! use getScalarType

Co-authored-by: Matt Arsenault <arsenm2 at gmail.com>
---
 llvm/lib/CodeGen/GlobalISel/Utils.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 8fa3e47c74ac1..55896795c888a 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1182,8 +1182,8 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
   // vector and one is a scalar. If both types are scalars, the GCD type is the
   // GCD between the two scalar sizes. If one is vector and one is scalar, then
   // the GCD type is the GCD between the scalar and the vector element size.
-  LLT OrigScalar = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;
-  LLT TargetScalar = TargetTy.isVector() ? TargetTy.getElementType() : TargetTy;
+  LLT OrigScalar = OrigTy.getScalarType();
+  LLT TargetScalar = TargetTy.getScalarType();
   unsigned GCD = std::gcd(OrigScalar.getSizeInBits().getFixedValue(),
                           TargetScalar.getSizeInBits().getFixedValue());
   return LLT::scalar(GCD);

>From fce4d6b4cff60a0e0a612079f4f2e560990a7f9a Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 6 Feb 2024 07:06:49 -0800
Subject: [PATCH 4/5] fixup! simplify known min size calculation

---
 llvm/lib/CodeGen/GlobalISel/Utils.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 55896795c888a..b884ec11c9b9f 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1150,10 +1150,8 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
       llvm_unreachable(
           "getGCDType not implemented between fixed and scalable vectors.");
 
-    unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
-                                OrigElt.getSizeInBits().getFixedValue(),
-                            TargetTy.getElementCount().getKnownMinValue() *
-                                TargetElt.getSizeInBits().getFixedValue());
+    unsigned GCD = std::gcd(OrigTy.getSizeInBits().getKnownMinValue(),
+                            TargetTy.getSizeInBits().getKnownMinValue());
     if (GCD == OrigElt.getSizeInBits())
       return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
                                  OrigElt);

>From 9f07add7864481e2eb4dc2821dcdc1ba2502c889 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 6 Feb 2024 07:08:57 -0800
Subject: [PATCH 5/5] fixup! use assert instead of unreachable

---
 llvm/lib/CodeGen/GlobalISel/Utils.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index b884ec11c9b9f..85e9af9625c9f 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1137,7 +1137,6 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
 
   if (OrigTy.isVector() && TargetTy.isVector()) {
     LLT OrigElt = OrigTy.getElementType();
-    LLT TargetElt = TargetTy.getElementType();
 
     // TODO: The docstring for this function says the intention is to use this
     // function to build MERGE/UNMERGE instructions. It won't be the case that
@@ -1145,10 +1144,9 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
     // could implement getGCDType between the two in the future if there was a
     // need, but it is not worth it now as this function should not be used in
     // that way.
-    if ((OrigTy.isScalableVector() && TargetTy.isFixedVector()) ||
-        (OrigTy.isFixedVector() && TargetTy.isScalableVector()))
-      llvm_unreachable(
-          "getGCDType not implemented between fixed and scalable vectors.");
+    assert(((OrigTy.isScalableVector() && !TargetTy.isFixedVector()) ||
+            (OrigTy.isFixedVector() && !TargetTy.isScalableVector())) &&
+           "getGCDType not implemented between fixed and scalable vectors.");
 
     unsigned GCD = std::gcd(OrigTy.getSizeInBits().getKnownMinValue(),
                             TargetTy.getSizeInBits().getKnownMinValue());



More information about the llvm-commits mailing list