[llvm] [AArch64] Fix failure with inline asm and svcount (PR #112537)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 24 09:23:36 PDT 2024


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/112537

>From d5250c6dcf017f31810f5a10eca8884a106d68c1 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 16 Oct 2024 09:56:14 +0100
Subject: [PATCH 1/2] [AArch64] Fix failure with inline asm and svcount

This fixes an issue where the compiler runs into an assertion
failure for the following example:

  register svcount_t pred asm("pn8") = svptrue_c8();
  asm("ld1w { z0.s, z4.s, z8.s, z12.s }, %[pred]/z, [x0]\n"
    :
    : [pred] "Uph" (pred)
    : "memory", "cc");

Here the register constraint that ends up in the LLVM IR is "{pn8}",
but the code in `TargetLowering::getRegForInlineAsmConstraint`
that parses that string follows a path where it queries a
suitable register class for this register (the PPRorPNR regclass),
for which it then chooses `nxv16i1` as a suitable type. These
choices individually are correct, but the combined result isn't,
because the type should be `aarch64svcount`.
This then results in issues later on in SelectionDAGBuilder.cpp
in CopyToReg because the type of the actual value and the computed
type from the constraint don't match.

This PR pre-empts this issue by parsing explicit predicate register
constraints ("{pN}" and "{pnN}") directly and returning the matching
register class (PPR or PNR).
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 32 ++++++++++
 llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll  | 64 +++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9adb11292376ce..5a848ada9dd8ee 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -11867,6 +11867,36 @@ const char *AArch64TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
 
 enum class PredicateConstraint { Uph, Upl, Upa };
 
+// Returns a {Reg, RegisterClass} pair if the constraint names a specific SVE
+// predicate register, i.e. "{pN}" or "{pnN}" with N in [0, 15]; otherwise
+// returns std::nullopt so the caller falls back to its other handling.
+//
+// For a constraint like "{pn3}" the default path in
+// TargetLowering::getRegForInlineAsmConstraint() picks the "PPRorPNR"
+// register class and then nxv16i1 as the type, which does not match an
+// svcount operand. Matching the register explicitly here pre-empts that.
+static std::optional<std::pair<unsigned, const TargetRegisterClass *>>
+parsePredicateRegAsConstraint(StringRef Constraint) {
+  if (!Constraint.starts_with('{') || !Constraint.ends_with('}') ||
+      Constraint[1] != 'p')
+    return std::nullopt;
+
+  Constraint = Constraint.substr(2, Constraint.size() - 3);
+  bool IsPredicateAsCount = Constraint.starts_with("n");
+  if (IsPredicateAsCount)
+    Constraint = Constraint.drop_front(1);
+
+  // SVE has only 16 predicate registers (p0-p15, also named pn0-pn15); a
+  // larger index would make P0 + V / PN0 + V name an unrelated register.
+  unsigned V;
+  if (Constraint.getAsInteger(10, V) || V > 15)
+    return std::nullopt;
+
+  if (IsPredicateAsCount)
+    return std::make_pair(AArch64::PN0 + V, &AArch64::PNRRegClass);
+  return std::make_pair(AArch64::P0 + V, &AArch64::PPRRegClass);
+}
+
 static std::optional<PredicateConstraint>
 parsePredicateConstraint(StringRef Constraint) {
   return StringSwitch<std::optional<PredicateConstraint>>(Constraint)
@@ -12114,6 +12144,8 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
       break;
     }
   } else {
+    if (const auto P = parsePredicateRegAsConstraint(Constraint))
+      return *P;
     if (const auto PC = parsePredicateConstraint(Constraint))
       if (const auto *RegClass = getPredicateRegisterClass(*PC, VT))
         return std::make_pair(0U, RegClass);
diff --git a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
index 9f8897575b3d58..f587f315b6658d 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
@@ -119,3 +119,67 @@ define <vscale x 8 x half> @test_svfadd_f16_Uph_constraint(<vscale x 16 x i1> %P
   %1 = tail call <vscale x 8 x half> asm "fadd $0.h, $1/m, $2.h, $3.h", "=w, at 3Uph,w,w"(<vscale x 16 x i1> %Pg, <vscale x 8 x half> %Zn, <vscale x 8 x half> %Zm)
   ret <vscale x 8 x half> %1
 }
+
+define void @explicit_p0(ptr %p) { ; "{p0}" on an nxv16i1 operand is bound to $p0 (ppr value)
+  ; CHECK-LABEL: name: explicit_p0
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_B:%[0-9]+]]:ppr = PTRUE_B 31, implicit $vg
+  ; CHECK-NEXT:   $p0 = COPY [[PTRUE_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.b8(i32 31)
+  %2 = tail call i64 asm sideeffect "ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", "=r,{p0},0"(<vscale x 16 x i1> %1, ptr %p)
+  ret void
+}
+
+define void @explicit_p8_invalid(ptr %p) { ; NOTE: "invalid" presumably because p8 can't govern ld4w; only checks MIR keeps $p8
+  ; CHECK-LABEL: name: explicit_p8_invalid
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_B:%[0-9]+]]:ppr = PTRUE_B 31, implicit $vg
+  ; CHECK-NEXT:   $p8 = COPY [[PTRUE_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.b8(i32 31)
+  %2 = tail call i64 asm sideeffect "ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", "=r,{p8},0"(<vscale x 16 x i1> %1, ptr %p)
+  ret void
+}
+
+define void @explicit_pn8(ptr %p) { ; "{pn8}" on an svcount operand (the case that used to assert) is bound to $pn8
+  ; CHECK-LABEL: name: explicit_pn8
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_C_B:%[0-9]+]]:pnr_p8to15 = PTRUE_C_B implicit $vg
+  ; CHECK-NEXT:   $pn8 = COPY [[PTRUE_C_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %2 = tail call i64 asm sideeffect "ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", "=r,{pn8},0"(target("aarch64.svcount") %1, ptr %p)
+  ret void
+}
+
+define void @explicit_pn0_invalid(ptr %p) { ; NOTE: "invalid" presumably because pn0 is not encodable for this ld1w; only checks MIR keeps $pn0
+  ; CHECK-LABEL: name: explicit_pn0_invalid
+  ; CHECK: bb.0 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:gpr64 = COPY $x0
+  ; CHECK-NEXT:   [[PTRUE_C_B:%[0-9]+]]:pnr_p8to15 = PTRUE_C_B implicit $vg
+  ; CHECK-NEXT:   $pn0 = COPY [[PTRUE_C_B]]
+  ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
+  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   RET_ReallyLR
+  %1 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
+  %2 = tail call i64 asm sideeffect "ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", "=r,{pn0},0"(target("aarch64.svcount") %1, ptr %p)
+  ret void
+}

>From 9da6284fda909371ec67debc9eb517a166eac130 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 24 Oct 2024 16:21:17 +0000
Subject: [PATCH 2/2] Fix up magic numbers in MIR test

---
 llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
index f587f315b6658d..ff66206228a4aa 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
@@ -129,7 +129,7 @@ define void @explicit_p0(ptr %p) {
   ; CHECK-NEXT:   [[PTRUE_B:%[0-9]+]]:ppr = PTRUE_B 31, implicit $vg
   ; CHECK-NEXT:   $p0 = COPY [[PTRUE_B]]
   ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
-  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 3538954 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
   ; CHECK-NEXT:   RET_ReallyLR
   %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.b8(i32 31)
   %2 = tail call i64 asm sideeffect "ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", "=r,{p0},0"(<vscale x 16 x i1> %1, ptr %p)
@@ -145,7 +145,7 @@ define void @explicit_p8_invalid(ptr %p) {
   ; CHECK-NEXT:   [[PTRUE_B:%[0-9]+]]:ppr = PTRUE_B 31, implicit $vg
   ; CHECK-NEXT:   $p8 = COPY [[PTRUE_B]]
   ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
-  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   INLINEASM &"ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 3538954 /* regdef:GPR64common */, def %1, 9 /* reguse */, $p8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
   ; CHECK-NEXT:   RET_ReallyLR
   %1 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.b8(i32 31)
   %2 = tail call i64 asm sideeffect "ld4w { z0.s, z1.s, z2.s, z3.s }, $1/z, [$0]", "=r,{p8},0"(<vscale x 16 x i1> %1, ptr %p)
@@ -161,7 +161,7 @@ define void @explicit_pn8(ptr %p) {
   ; CHECK-NEXT:   [[PTRUE_C_B:%[0-9]+]]:pnr_p8to15 = PTRUE_C_B implicit $vg
   ; CHECK-NEXT:   $pn8 = COPY [[PTRUE_C_B]]
   ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
-  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 3538954 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn8, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
   ; CHECK-NEXT:   RET_ReallyLR
   %1 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
   %2 = tail call i64 asm sideeffect "ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", "=r,{pn8},0"(target("aarch64.svcount") %1, ptr %p)
@@ -177,7 +177,7 @@ define void @explicit_pn0_invalid(ptr %p) {
   ; CHECK-NEXT:   [[PTRUE_C_B:%[0-9]+]]:pnr_p8to15 = PTRUE_C_B implicit $vg
   ; CHECK-NEXT:   $pn0 = COPY [[PTRUE_C_B]]
   ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:gpr64common = COPY [[COPY]]
-  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 2818058 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
+  ; CHECK-NEXT:   INLINEASM &"ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", 1 /* sideeffect attdialect */, 3538954 /* regdef:GPR64common */, def %1, 9 /* reguse */, $pn0, 2147483657 /* reguse tiedto:$0 */, [[COPY1]](tied-def 3)
   ; CHECK-NEXT:   RET_ReallyLR
   %1 = tail call target("aarch64.svcount") @llvm.aarch64.sve.ptrue.c8()
   %2 = tail call i64 asm sideeffect "ld1w { z0.s, z4.s, z8.s, z12.s }, $1/z, [$0]", "=r,{pn0},0"(target("aarch64.svcount") %1, ptr %p)



More information about the llvm-commits mailing list