[llvm] 1f62af6 - [AArch64][SelectionDAG] Support passing/returning scalable vectors with unusual types.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 2 15:53:41 PDT 2021


Author: Eli Friedman
Date: 2021-08-02T15:53:16-07:00
New Revision: 1f62af63467e4834e1e386619b3eccab245489d4

URL: https://github.com/llvm/llvm-project/commit/1f62af63467e4834e1e386619b3eccab245489d4
DIFF: https://github.com/llvm/llvm-project/commit/1f62af63467e4834e1e386619b3eccab245489d4.diff

LOG: [AArch64][SelectionDAG] Support passing/returning scalable vectors with unusual types.

This adds handling for two cases:

1. A scalable vector where the element type is promoted.
2. A scalable vector where the element count is odd (or more generally,
   not divisible by the element count of the part type).

(Some element types still don't work; for example, <vscale x 2 x i128>,
or <vscale x 2 x fp128>.)

Differential Revision: https://reviews.llvm.org/D105591

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index a085483939797..be00ec13a57df 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -399,29 +399,31 @@ static SDValue getCopyFromPartsVector(SelectionDAG &DAG, const SDLoc &DL,
     return Val;
 
   if (PartEVT.isVector()) {
+    // Vector/Vector bitcast.
+    if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
+      return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
+
     // If the element type of the source/dest vectors are the same, but the
     // parts vector has more elements than the value vector, then we have a
     // vector widening case (e.g. <2 x float> -> <4 x float>).  Extract the
     // elements we want.
-    if (PartEVT.getVectorElementType() == ValueVT.getVectorElementType()) {
+    if (PartEVT.getVectorElementCount() != ValueVT.getVectorElementCount()) {
       assert((PartEVT.getVectorElementCount().getKnownMinValue() >
               ValueVT.getVectorElementCount().getKnownMinValue()) &&
              (PartEVT.getVectorElementCount().isScalable() ==
               ValueVT.getVectorElementCount().isScalable()) &&
              "Cannot narrow, it would be a lossy transformation");
-      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ValueVT, Val,
-                         DAG.getVectorIdxConstant(0, DL));
+      PartEVT =
+          EVT::getVectorVT(*DAG.getContext(), PartEVT.getVectorElementType(),
+                           ValueVT.getVectorElementCount());
+      Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, PartEVT, Val,
+                        DAG.getVectorIdxConstant(0, DL));
+      if (PartEVT == ValueVT)
+        return Val;
     }
 
-    // Vector/Vector bitcast.
-    if (ValueVT.getSizeInBits() == PartEVT.getSizeInBits())
-      return DAG.getNode(ISD::BITCAST, DL, ValueVT, Val);
-
-    assert(PartEVT.getVectorElementCount() == ValueVT.getVectorElementCount() &&
-      "Cannot handle this kind of promotion");
     // Promoted vector extract
     return DAG.getAnyExtOrTrunc(Val, DL, ValueVT);
-
   }
 
   // Trivial bitcast if the types are the same size and the destination
@@ -726,15 +728,19 @@ static void getCopyToPartsVector(SelectionDAG &DAG, const SDLoc &DL,
   } else if (ValueVT.getSizeInBits() == BuiltVectorTy.getSizeInBits()) {
     // Bitconvert vector->vector case.
     Val = DAG.getNode(ISD::BITCAST, DL, BuiltVectorTy, Val);
-  } else if (SDValue Widened =
-                 widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) {
-    Val = Widened;
-  } else if (BuiltVectorTy.getVectorElementType().bitsGE(
-                 ValueVT.getVectorElementType()) &&
-             BuiltVectorTy.getVectorElementCount() ==
-                 ValueVT.getVectorElementCount()) {
-    // Promoted vector extract
-    Val = DAG.getAnyExtOrTrunc(Val, DL, BuiltVectorTy);
+  } else {
+    if (BuiltVectorTy.getVectorElementType().bitsGT(
+            ValueVT.getVectorElementType())) {
+      // Integer promotion.
+      ValueVT = EVT::getVectorVT(*DAG.getContext(),
+                                 BuiltVectorTy.getVectorElementType(),
+                                 ValueVT.getVectorElementCount());
+      Val = DAG.getNode(ISD::ANY_EXTEND, DL, ValueVT, Val);
+    }
+
+    if (SDValue Widened = widenVectorToPartType(DAG, Val, DL, BuiltVectorTy)) {
+      Val = Widened;
+    }
   }
 
   assert(Val.getValueType() == BuiltVectorTy && "Unexpected vector value type");

diff  --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 3c5dd29036db0..842280ab953e3 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1556,7 +1556,7 @@ unsigned TargetLoweringBase::getVectorTypeBreakdown(LLVMContext &Context,
 
   // Scalable vectors cannot be scalarized, so handle the legalisation of the
   // types like done elsewhere in SelectionDAG.
-  if (VT.isScalableVector() && !isPowerOf2_32(EltCnt.getKnownMinValue())) {
+  if (EltCnt.isScalable()) {
     LegalizeKind LK;
     EVT PartVT = VT;
     do {
@@ -1565,16 +1565,14 @@ unsigned TargetLoweringBase::getVectorTypeBreakdown(LLVMContext &Context,
       PartVT = LK.second;
     } while (LK.first != TypeLegal);
 
-    NumIntermediates = VT.getVectorElementCount().getKnownMinValue() /
-                       PartVT.getVectorElementCount().getKnownMinValue();
+    if (!PartVT.isVector()) {
+      report_fatal_error(
+          "Don't know how to legalize this scalable vector type");
+    }
 
-    // FIXME: This code needs to be extended to handle more complex vector
-    // breakdowns, like nxv7i64 -> nxv8i64 -> 4 x nxv2i64. Currently the only
-    // supported cases are vectors that are broken down into equal parts
-    // such as nxv6i64 -> 3 x nxv2i64.
-    assert((PartVT.getVectorElementCount() * NumIntermediates) ==
-               VT.getVectorElementCount() &&
-           "Expected an integer multiple of PartVT");
+    NumIntermediates =
+        divideCeil(VT.getVectorElementCount().getKnownMinValue(),
+                   PartVT.getVectorElementCount().getKnownMinValue());
     IntermediateVT = PartVT;
     RegisterVT = getRegisterType(Context, IntermediateVT);
     return NumIntermediates;

diff  --git a/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll b/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
index 2813b9d3719bf..63184513a769e 100644
--- a/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
+++ b/llvm/test/CodeGen/AArch64/sve-breakdown-scalable-vectortype.ll
@@ -689,3 +689,64 @@ L1:
 L2:
   ret <vscale x 8 x double> %illegal
 }
+
+define <vscale x 8 x i63> @wide_8i63(i1 %b, <vscale x 16 x i8> %legal, <vscale x 8 x i63> %illegal) nounwind {
+; CHECK-LABEL: wide_8i63:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    tbnz w0, #0, .LBB21_2
+; CHECK-NEXT:  // %bb.1: // %L2
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    mov z1.d, z2.d
+; CHECK-NEXT:    mov z2.d, z3.d
+; CHECK-NEXT:    mov z3.d, z4.d
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB21_2: // %L1
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl bar
+  br i1 %b, label %L1, label %L2
+L1:
+  call aarch64_sve_vector_pcs void @bar()
+  unreachable
+L2:
+  ret <vscale x 8 x i63> %illegal
+}
+
+define <vscale x 7 x i63> @wide_7i63(i1 %b, <vscale x 16 x i8> %legal, <vscale x 7 x i63> %illegal) nounwind {
+; CHECK-LABEL: wide_7i63:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    tbnz w0, #0, .LBB22_2
+; CHECK-NEXT:  // %bb.1: // %L2
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    mov z1.d, z2.d
+; CHECK-NEXT:    mov z2.d, z3.d
+; CHECK-NEXT:    mov z3.d, z4.d
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB22_2: // %L1
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl bar
+  br i1 %b, label %L1, label %L2
+L1:
+  call aarch64_sve_vector_pcs void @bar()
+  unreachable
+L2:
+  ret <vscale x 7 x i63> %illegal
+}
+
+define <vscale x 7 x i31> @wide_7i31(i1 %b, <vscale x 16 x i8> %legal, <vscale x 7 x i31> %illegal) nounwind {
+; CHECK-LABEL: wide_7i31:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    tbnz w0, #0, .LBB23_2
+; CHECK-NEXT:  // %bb.1: // %L2
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    mov z1.d, z2.d
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB23_2: // %L1
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl bar
+  br i1 %b, label %L1, label %L2
+L1:
+  call aarch64_sve_vector_pcs void @bar()
+  unreachable
+L2:
+  ret <vscale x 7 x i31> %illegal
+}


        


More information about the llvm-commits mailing list