[llvm] [GlobalISel][AArch64] Legalize G_EXTRACT_VECTOR_ELT for SVE (PR #115161)

Thorsten Schütt via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 7 10:30:24 PST 2024


https://github.com/tschuett updated https://github.com/llvm/llvm-project/pull/115161

>From a0589d0f307a96493a7731ad4b7644cc443d84e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thorsten=20Sch=C3=BCtt?= <schuett at gmail.com>
Date: Sat, 2 Nov 2024 09:17:36 +0100
Subject: [PATCH] [GlobalISel][AArch64] Legalize G_EXTRACT_VECTOR_ELT for SVE

AArch64InstrGISel.td defines:
def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>;

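Because of this mapping, GlobalISel can reuse the many existing SelectionDAG
vector_extract patterns for SVE. Exploit that by treating G_EXTRACT_VECTOR_ELT
on scalable vectors as legal, so instruction selection can pick those patterns.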
---
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |   6 +-
 .../GISel/AArch64InstructionSelector.cpp      |   4 +-
 .../AArch64/GISel/AArch64LegalizerInfo.cpp    |  26 ++++-
 .../CodeGen/AArch64/extract-vector-elt-sve.ll | 105 ++++++++++++++++++
 4 files changed, 132 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index a87754389cc8ed..8cd3fa5f432b6e 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -3227,8 +3227,10 @@ bool IRTranslator::translateExtractElement(const User &U,
                                            MachineIRBuilder &MIRBuilder) {
   // If it is a <1 x Ty> vector, use the scalar as it is
   // not a legal vector type in LLT.
-  if (cast<FixedVectorType>(U.getOperand(0)->getType())->getNumElements() == 1)
-    return translateCopy(U, *U.getOperand(0), MIRBuilder);
+  if (const FixedVectorType *FVT =
+          dyn_cast<FixedVectorType>(U.getOperand(0)->getType()))
+    if (FVT->getNumElements() == 1)
+      return translateCopy(U, *U.getOperand(0), MIRBuilder);
 
   Register Res = getOrCreateVReg(U);
   Register Val = getOrCreateVReg(*U.getOperand(0));
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 9502b1d10f9a2b..663117c6b85bf7 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -5316,7 +5316,9 @@ bool AArch64InstructionSelector::selectUSMovFromExtend(
     return false;
   Register Src0 = Extract->getOperand(1).getReg();
 
-  const LLT &VecTy = MRI.getType(Src0);
+  const LLT VecTy = MRI.getType(Src0);
+  if (VecTy.isScalableVector())
+    return false;
 
   if (VecTy.getSizeInBits() != 128) {
     const MachineInstr *ScalarToVector = emitScalarToVector(
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index f7ca0ca65ac42b..3677cfdaba3b21 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -933,9 +933,16 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
         });
   }
 
+  // TODO : nxv4s16, nxv2s16, nxv2s32
   getActionDefinitionsBuilder(G_EXTRACT_VECTOR_ELT)
+      .legalFor(HasSVE, {{s16, nxv16s8, s64},
+                         {s16, nxv8s16, s64},
+                         {s32, nxv4s32, s64},
+                         {s64, nxv2s64, s64}})
       .unsupportedIf([=](const LegalityQuery &Query) {
         const LLT &EltTy = Query.Types[1].getElementType();
+        if (Query.Types[1].isScalableVector())
+          return false;
         return Query.Types[0] != EltTy;
       })
       .minScalar(2, s64)
@@ -949,22 +956,26 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
           [=](const LegalityQuery &Query) {
             // We want to promote to <M x s1> to <M x s64> if that wouldn't
             // cause the total vec size to be > 128b.
-            return Query.Types[1].getNumElements() <= 2;
+            return Query.Types[1].isFixedVector() &&
+                   Query.Types[1].getNumElements() <= 2;
           },
           0, s64)
       .minScalarOrEltIf(
           [=](const LegalityQuery &Query) {
-            return Query.Types[1].getNumElements() <= 4;
+            return Query.Types[1].isFixedVector() &&
+                   Query.Types[1].getNumElements() <= 4;
           },
           0, s32)
       .minScalarOrEltIf(
           [=](const LegalityQuery &Query) {
-            return Query.Types[1].getNumElements() <= 8;
+            return Query.Types[1].isFixedVector() &&
+                   Query.Types[1].getNumElements() <= 8;
           },
           0, s16)
       .minScalarOrEltIf(
           [=](const LegalityQuery &Query) {
-            return Query.Types[1].getNumElements() <= 16;
+            return Query.Types[1].isFixedVector() &&
+                   Query.Types[1].getNumElements() <= 16;
           },
           0, s8)
       .minScalarOrElt(0, s8) // Worst case, we need at least s8.
@@ -2178,11 +2189,14 @@ bool AArch64LegalizerInfo::legalizeMemOps(MachineInstr &MI,
 
 bool AArch64LegalizerInfo::legalizeExtractVectorElt(
     MachineInstr &MI, MachineRegisterInfo &MRI, LegalizerHelper &Helper) const {
-  assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
+  const GExtractVectorElement *Element = cast<GExtractVectorElement>(&MI);
   auto VRegAndVal =
-      getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
+      getIConstantVRegValWithLookThrough(Element->getIndexReg(), MRI);
   if (VRegAndVal)
     return true;
+  LLT VecTy = MRI.getType(Element->getVectorReg());
+  if (VecTy.isScalableVector())
+    return true;
   return Helper.lowerExtractInsertVectorElt(MI) !=
          LegalizerHelper::LegalizeResult::UnableToLegalize;
 }
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-elt-sve.ll b/llvm/test/CodeGen/AArch64/extract-vector-elt-sve.ll
index 75c8f8923c3815..d18af3d5ae9450 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-elt-sve.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-elt-sve.ll
@@ -121,3 +121,108 @@ entry:
   %d = insertelement <vscale x 16 x i8> %vec, i8 %elt, i64 %idx
   ret <vscale x 16 x i8> %d
 }
+
+define i64 @extract_vscale_2_i64(<vscale x 2 x i64> %vec, i64 %idx) {
+; CHECK-SD-LABEL: extract_vscale_2_i64:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    whilels p0.d, xzr, x0
+; CHECK-SD-NEXT:    lastb x0, p0, z0.d
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_vscale_2_i64:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    whilels p0.d, xzr, x0
+; CHECK-GI-NEXT:    lastb d0, p0, z0.d
+; CHECK-GI-NEXT:    fmov x0, d0
+; CHECK-GI-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 2 x i64> %vec, i64 %idx
+  ret i64 %d
+}
+
+define i64 @extract_vscale_2_i64_zero(<vscale x 2 x i64> %vec, i64 %idx) {
+; CHECK-LABEL: extract_vscale_2_i64_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fmov x0, d0
+; CHECK-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 2 x i64> %vec, i64 0
+  ret i64 %d
+}
+
+define i32 @extract_vscale_4_i32(<vscale x 4 x i32> %vec, i64 %idx) {
+; CHECK-SD-LABEL: extract_vscale_4_i32:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    whilels p0.s, xzr, x0
+; CHECK-SD-NEXT:    lastb w0, p0, z0.s
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_vscale_4_i32:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    whilels p0.s, xzr, x0
+; CHECK-GI-NEXT:    lastb s0, p0, z0.s
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 4 x i32> %vec, i64 %idx
+  ret i32 %d
+}
+
+define i32 @extract_vscale_4_i32_zero(<vscale x 4 x i32> %vec, i64 %idx) {
+; CHECK-LABEL: extract_vscale_4_i32_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 4 x i32> %vec, i64 0
+  ret i32 %d
+}
+
+define i16 @extract_vscale_8_i16(<vscale x 8 x i16> %vec, i64 %idx) {
+; CHECK-SD-LABEL: extract_vscale_8_i16:
+; CHECK-SD:       // %bb.0: // %entry
+; CHECK-SD-NEXT:    whilels p0.h, xzr, x0
+; CHECK-SD-NEXT:    lastb w0, p0, z0.h
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: extract_vscale_8_i16:
+; CHECK-GI:       // %bb.0: // %entry
+; CHECK-GI-NEXT:    whilels p0.h, xzr, x0
+; CHECK-GI-NEXT:    lastb h0, p0, z0.h
+; CHECK-GI-NEXT:    fmov w0, s0
+; CHECK-GI-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 8 x i16> %vec, i64 %idx
+  ret i16 %d
+}
+
+define i16 @extract_vscale_8_i16_zero(<vscale x 8 x i16> %vec, i64 %idx) {
+; CHECK-LABEL: extract_vscale_8_i16_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 8 x i16> %vec, i64 0
+  ret i16 %d
+}
+
+define i8 @extract_vscale_16_i8(<vscale x 16 x i8> %vec, i64 %idx) {
+; CHECK-LABEL: extract_vscale_16_i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    whilels p0.b, xzr, x0
+; CHECK-NEXT:    lastb w0, p0, z0.b
+; CHECK-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 16 x i8> %vec, i64 %idx
+  ret i8 %d
+}
+
+define i8 @extract_vscale_16_i8_zero(<vscale x 16 x i8> %vec, i64 %idx) {
+; CHECK-LABEL: extract_vscale_16_i8_zero:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    ret
+entry:
+  %d = extractelement <vscale x 16 x i8> %vec, i64 0
+  ret i8 %d
+}



More information about the llvm-commits mailing list