[llvm] [GISel] Add support for scalable vectors in getGCDType (PR #80307)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 05:05:16 PST 2024


================
@@ -1132,45 +1132,60 @@ LLT llvm::getCoverTy(LLT OrigTy, LLT TargetTy) {
 }
 
 LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
-  const unsigned OrigSize = OrigTy.getSizeInBits();
-  const unsigned TargetSize = TargetTy.getSizeInBits();
-
-  if (OrigSize == TargetSize)
+  if (OrigTy.getSizeInBits() == TargetTy.getSizeInBits())
     return OrigTy;
 
-  if (OrigTy.isVector()) {
+  if (OrigTy.isVector() && TargetTy.isVector()) {
     LLT OrigElt = OrigTy.getElementType();
-    if (TargetTy.isVector()) {
-      LLT TargetElt = TargetTy.getElementType();
-      if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
-        int GCD = std::gcd(OrigTy.getNumElements(), TargetTy.getNumElements());
-        return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
-      }
-    } else {
-      // If the source is a vector of pointers, return a pointer element.
-      if (OrigElt.getSizeInBits() == TargetSize)
-        return OrigElt;
-    }
+    LLT TargetElt = TargetTy.getElementType();
 
-    unsigned GCD = std::gcd(OrigSize, TargetSize);
+    // TODO: The docstring for this function says it is intended for building
+    // MERGE/UNMERGE instructions. We never generate a MERGE/UNMERGE between
+    // fixed and scalable vector types, so getGCDType is not implemented for
+    // that combination. It could be implemented in the future if a need
+    // arises, but it is not worth doing now because this function should not
+    // be used in that way.
+    if ((OrigTy.isScalableVector() && TargetTy.isFixedVector()) ||
+        (OrigTy.isFixedVector() && TargetTy.isScalableVector()))
+      llvm_unreachable(
+          "getGCDType not implemented between fixed and scalable vectors.");
+
+    unsigned GCD = std::gcd(OrigTy.getElementCount().getKnownMinValue() *
+                                OrigElt.getSizeInBits().getFixedValue(),
+                            TargetTy.getElementCount().getKnownMinValue() *
+                                TargetElt.getSizeInBits().getFixedValue());
     if (GCD == OrigElt.getSizeInBits())
-      return OrigElt;
+      return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                 OrigElt);
 
-    // If we can't produce the original element type, we have to use a smaller
-    // scalar.
+    // Cannot produce original element type, but both have vscale in common.
     if (GCD < OrigElt.getSizeInBits())
-      return LLT::scalar(GCD);
-    return LLT::fixed_vector(GCD / OrigElt.getSizeInBits(), OrigElt);
-  }
+      return LLT::scalarOrVector(ElementCount::get(1, OrigTy.isScalable()),
+                                 GCD);
 
-  if (TargetTy.isVector()) {
-    // Try to preserve the original element type.
-    LLT TargetElt = TargetTy.getElementType();
-    if (TargetElt.getSizeInBits() == OrigSize)
-      return OrigTy;
+    return LLT::vector(
+        ElementCount::get(GCD / OrigElt.getSizeInBits().getFixedValue(),
+                          OrigTy.isScalable()),
+        OrigElt);
   }
 
-  unsigned GCD = std::gcd(OrigSize, TargetSize);
+  // If one type is vector and the element size matches the scalar size, then
+  // the gcd is the scalar type.
+  if (OrigTy.isVector() &&
+      OrigTy.getElementType().getSizeInBits() == TargetTy.getSizeInBits())
+    return OrigTy.getElementType();
+  if (TargetTy.isVector() &&
+      TargetTy.getElementType().getSizeInBits() == OrigTy.getSizeInBits())
+    return OrigTy;
+
+  // At this point, both types are either scalars of different type or one is a
+  // vector and one is a scalar. If both types are scalars, the GCD type is the
+  // GCD between the two scalar sizes. If one is vector and one is scalar, then
+  // the GCD type is the GCD between the scalar and the vector element size.
+  LLT OrigScalar = OrigTy.isVector() ? OrigTy.getElementType() : OrigTy;
+  LLT TargetScalar = TargetTy.isVector() ? TargetTy.getElementType() : TargetTy;
----------------
arsenm wrote:

```suggestion
  LLT OrigScalar = OrigTy.getScalarType();
  LLT TargetScalar = TargetTy.getScalarType();
```

https://github.com/llvm/llvm-project/pull/80307


More information about the llvm-commits mailing list