[Lldb-commits] [lldb] e96adfd - [lldb][AArch64] Add testing for SME's ZA and SVG registers

David Spickett via lldb-commits lldb-commits at lists.llvm.org
Tue Sep 19 05:18:29 PDT 2023


Author: David Spickett
Date: 2023-09-19T12:18:23Z
New Revision: e96adfd0dbcf84c27be2087371890f4228890609

URL: https://github.com/llvm/llvm-project/commit/e96adfd0dbcf84c27be2087371890f4228890609
DIFF: https://github.com/llvm/llvm-project/commit/e96adfd0dbcf84c27be2087371890f4228890609.diff

LOG: [lldb][AArch64] Add testing for SME's ZA and SVG registers

An SME enabled program has the following extra state:
* Streaming mode or non-streaming mode.
* ZA enabled or disabled.
* The active vector length.

Covering every transition from each possible state to every other
possible state is not viable. The testing added here is therefore a cross
section of those transitions, all of which found real bugs in LLDB and the
Linux kernel during development.

Many of those transitions are not possible via LLDB
(e.g. disabling ZA), and many more are possible but unlikely to
occur in normal use.

Added testing:
* TestSVEThreadedDynamic now checks for correct SVG values.
* New test TestZAThreadedDynamic creates 3 threads with different ZA sizes
  and states and switches between them verifying the register value
  (derived from the existing threaded SVE test).
* New test TestZARegisterSaveRestore starts in a given SME state, runs a
  set of expressions in various orders, then checks that the original
  state has been restored.
* TestArm64DynamicRegsets has ZA and SVG checks added, including writing
  to ZA to enable it.

As usual, running these tests requires QEMU (there is no real SME
hardware available at this time) and a very recent kernel.

Reviewed By: omjavaid

Differential Revision: https://reviews.llvm.org/D159505

Added: 
    lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile
    lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
    lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c
    lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile
    lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py
    lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c

Modified: 
    lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
    lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py

Removed: 
    


################################################################################
diff  --git a/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py b/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
index d3f53f0e95dfcb5..4f4da2b5223fb15 100644
--- a/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
+++ b/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py
@@ -70,15 +70,13 @@ def sve_regs_read_dynamic(self, sve_registers):
         self.runCmd("register write ffr " + "'" + p_regs_value + "'")
         self.expect("register read ffr", substrs=[p_regs_value])
 
-    @no_debug_info_test
-    @skipIf(archs=no_match(["aarch64"]))
-    @skipIf(oslist=no_match(["linux"]))
-    def test_aarch64_dynamic_regset_config(self):
-        """Test AArch64 Dynamic Register sets configuration."""
+    def setup_register_config_test(self, run_args=None):
         self.build()
         self.line = line_number("main.c", "// Set a break point here.")
 
         exe = self.getBuildArtifact("a.out")
+        if run_args is not None:
+            self.runCmd("settings set target.run-args " + run_args)
         self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
 
         lldbutil.run_break_set_by_file_and_line(
@@ -92,12 +90,16 @@ def test_aarch64_dynamic_regset_config(self):
             substrs=["stop reason = breakpoint 1."],
         )
 
-        target = self.dbg.GetSelectedTarget()
-        process = target.GetProcess()
-        thread = process.GetThreadAtIndex(0)
-        currentFrame = thread.GetFrameAtIndex(0)
+        return self.thread().GetSelectedFrame().GetRegisters()
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_aarch64_dynamic_regset_config(self):
+        """Test AArch64 Dynamic Register sets configuration."""
+        register_sets = self.setup_register_config_test()
 
-        for registerSet in currentFrame.GetRegisters():
+        for registerSet in register_sets:
             if "Scalable Vector Extension Registers" in registerSet.GetName():
                 self.assertTrue(
                     self.isAArch64SVE(),
@@ -120,6 +122,19 @@ def test_aarch64_dynamic_regset_config(self):
                 )
                 self.expect("register read data_mask", substrs=["data_mask = 0x"])
                 self.expect("register read code_mask", substrs=["code_mask = 0x"])
+            if "Scalable Matrix Extension Registers" in registerSet.GetName():
+                self.assertTrue(
+                    self.isAArch64SME(),
+                    "LLDB Enabled SME register set when it was disabled by target",
+                )
+
+    def make_za_value(self, vl, generator):
+        # Generate a vector value string "{0x00 0x01....}".
+        rows = []
+        for row in range(vl):
+            byte = "0x{:02x}".format(generator(row))
+            rows.append(" ".join([byte] * vl))
+        return "{" + " ".join(rows) + "}"
 
     @no_debug_info_test
     @skipIf(archs=no_match(["aarch64"]))
@@ -130,28 +145,58 @@ def test_aarch64_dynamic_regset_config_sme(self):
         if not self.isAArch64SME():
             self.skipTest("SME must be present.")
 
-        self.build()
-        self.line = line_number("main.c", "// Set a break point here.")
+        register_sets = self.setup_register_config_test("sme")
 
-        exe = self.getBuildArtifact("a.out")
-        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
-
-        lldbutil.run_break_set_by_file_and_line(
-            self, "main.c", self.line, num_expected_locations=1
+        ssve_registers = register_sets.GetFirstValueByName(
+            "Scalable Vector Extension Registers"
         )
-        self.runCmd("settings set target.run-args sme")
-        self.runCmd("run", RUN_SUCCEEDED)
+        self.assertTrue(ssve_registers.IsValid())
+        self.sve_regs_read_dynamic(ssve_registers)
 
-        self.expect(
-            "thread backtrace",
-            STOPPED_DUE_TO_BREAKPOINT,
-            substrs=["stop reason = breakpoint 1."],
+        sme_registers = register_sets.GetFirstValueByName(
+            "Scalable Matrix Extension Registers"
         )
+        self.assertTrue(sme_registers.IsValid())
 
-        register_sets = self.thread().GetSelectedFrame().GetRegisters()
+        vg = ssve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()
+        vl = vg * 8
+        # When first enabled it is all 0s.
+        self.expect("register read za", substrs=[self.make_za_value(vl, lambda r: 0)])
+        za_value = self.make_za_value(vl, lambda r: r + 1)
+        self.runCmd("register write za '{}'".format(za_value))
+        self.expect("register read za", substrs=[za_value])
 
-        ssve_registers = register_sets.GetFirstValueByName(
-            "Scalable Vector Extension Registers"
+        # SVG should match VG because we're in streaming mode.
+
+        self.assertTrue(sme_registers.IsValid())
+        svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
+        self.assertEqual(vg, svg)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_aarch64_dynamic_regset_config_sme_za_disabled(self):
+        """Test that ZA shows as 0s when disabled and can be enabled by writing
+        to it."""
+        if not self.isAArch64SME():
+            self.skipTest("SME must be present.")
+
+        # No argument, so ZA will be disabled when we break.
+        register_sets = self.setup_register_config_test()
+
+        # vg is the non-streaming vg as we are in non-streaming mode, so we need
+        # to use svg.
+        sme_registers = register_sets.GetFirstValueByName(
+            "Scalable Matrix Extension Registers"
         )
-        self.assertTrue(ssve_registers.IsValid())
-        self.sve_regs_read_dynamic(ssve_registers)
+        self.assertTrue(sme_registers.IsValid())
+        svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
+
+        svl = svg * 8
+        # A disabled ZA is shown as all 0s.
+        self.expect("register read za", substrs=[self.make_za_value(svl, lambda r: 0)])
+        za_value = self.make_za_value(svl, lambda r: r + 1)
+        # Writing to it enables ZA, so the value should be there when we read
+        # it back.
+        self.runCmd("register write za '{}'".format(za_value))
+        self.expect("register read za", substrs=[za_value])

diff  --git a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py b/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
index ecac3712674976b..8bcb76776459d01 100644
--- a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
+++ b/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py
@@ -98,6 +98,12 @@ def check_sve_registers(self, vg_test_value):
 
         self.expect("register read ffr", substrs=[p_regs_value])
 
+    def build_for_mode(self, mode):
+        cflags = "-march=armv8-a+sve -lpthread"
+        if mode == Mode.SSVE:
+            cflags += " -DUSE_SSVE"
+        self.build(dictionary={"CFLAGS_EXTRAS": cflags})
+
     def run_sve_test(self, mode):
         if (mode == Mode.SVE) and not self.isAArch64SVE():
             self.skipTest("SVE registers must be supported.")
@@ -105,12 +111,8 @@ def run_sve_test(self, mode):
         if (mode == Mode.SSVE) and not self.isAArch64SME():
             self.skipTest("Streaming SVE registers must be supported.")
 
-        cflags = "-march=armv8-a+sve -lpthread"
-        if mode == Mode.SSVE:
-            cflags += " -DUSE_SSVE"
-        self.build(dictionary={"CFLAGS_EXTRAS": cflags})
+        self.build_for_mode(mode)
 
-        self.build()
         supported_vg = self.get_supported_vg()
 
         if not (2 in supported_vg and 4 in supported_vg):
@@ -196,3 +198,94 @@ def test_sve_registers_dynamic_config(self):
     def test_ssve_registers_dynamic_config(self):
         """Test AArch64 SSVE registers multi-threaded dynamic resize."""
         self.run_sve_test(Mode.SSVE)
+
+    def setup_svg_test(self, mode):
+        # Even when running in SVE mode, we need access to SVG for these tests.
+        if not self.isAArch64SME():
+            self.skipTest("Streaming SVE registers must be present.")
+
+        self.build_for_mode(mode)
+
+        supported_vg = self.get_supported_vg()
+
+        main_thread_stop_line = line_number("main.c", "// Break in main thread")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line)
+
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        target = self.dbg.GetSelectedTarget()
+        process = target.GetProcess()
+
+        return process, supported_vg
+
+    def read_reg(self, process, regset, reg):
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sve_registers = registerSets.GetFirstValueByName(regset)
+        return sve_registers.GetChildMemberWithName(reg).GetValueAsUnsigned()
+
+    def read_vg(self, process):
+        return self.read_reg(process, "Scalable Vector Extension Registers", "vg")
+
+    def read_svg(self, process):
+        return self.read_reg(process, "Scalable Matrix Extension Registers", "svg")
+
+    def do_svg_test(self, process, vgs, expected_svgs):
+        for vg, svg in zip(vgs, expected_svgs):
+            self.runCmd("register write vg {}".format(vg))
+            self.assertEqual(svg, self.read_svg(process))
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_svg_sve_mode(self):
+        """When in SVE mode, svg should remain constant as we change vg."""
+        process, supported_vg = self.setup_svg_test(Mode.SVE)
+        svg = self.read_svg(process)
+        self.do_svg_test(process, supported_vg, [svg] * len(supported_vg))
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_svg_ssve_mode(self):
+        """When in SSVE mode, changing vg should change svg to the same value."""
+        process, supported_vg = self.setup_svg_test(Mode.SSVE)
+        self.do_svg_test(process, supported_vg, supported_vg)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_sme_not_present(self):
+        """When there is no SME, we should not show the SME register sets."""
+        if self.isAArch64SME():
+            self.skipTest("Streaming SVE registers must not be present.")
+
+        self.build_for_mode(Mode.SVE)
+
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+
+        # This test may run on a non-sve system, but we'll stop before any
+        # SVE instruction would be run.
+        self.runCmd("b main")
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        target = self.dbg.GetSelectedTarget()
+        process = target.GetProcess()
+
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sme_registers = registerSets.GetFirstValueByName(
+            "Scalable Matrix Extension Registers"
+        )
+        self.assertFalse(sme_registers.IsValid())

diff  --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile
new file mode 100644
index 000000000000000..57d926b37d45cf4
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile
@@ -0,0 +1,5 @@
+C_SOURCES := main.c
+
+CFLAGS_EXTRAS := -march=armv8-a+sve+sme -lpthread
+
+include Makefile.rules

diff  --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
new file mode 100644
index 000000000000000..65d1071c26b2a34
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py
@@ -0,0 +1,165 @@
+"""
+Test the AArch64 SME Array Storage (ZA) register dynamic resize with
+multiple threads.
+"""
+
+from enum import Enum
+import lldb
+from lldbsuite.test.decorators import *
+from lldbsuite.test.lldbtest import *
+from lldbsuite.test import lldbutil
+
+
+class AArch64ZAThreadedTestCase(TestBase):
+    def get_supported_vg(self):
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+
+        main_thread_stop_line = line_number("main.c", "// Break in main thread")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line)
+
+        self.runCmd("settings set target.run-args 0")
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        current_vg = self.match("register read vg", ["(0x[0-9]+)"])
+        self.assertTrue(current_vg is not None)
+        self.expect("register write vg {}".format(current_vg.group()))
+
+        # Aka 128, 256 and 512 bit.
+        supported_vg = []
+        for vg in [2, 4, 8]:
+            # This could mask other errors but writing vg is tested elsewhere
+            # so we assume the hardware rejected the value.
+            self.runCmd("register write vg {}".format(vg), check=False)
+            if not self.res.GetError():
+                supported_vg.append(vg)
+
+        self.runCmd("breakpoint delete 1")
+        self.runCmd("continue")
+
+        return supported_vg
+
+    def gen_za_value(self, svg, value_generator):
+        svl = svg * 8
+
+        rows = []
+        for row in range(svl):
+            byte = "0x{:02x}".format(value_generator(row))
+            rows.append(" ".join([byte] * svl))
+
+        return "{" + " ".join(rows) + "}"
+
+    def check_za_register(self, svg, value_offset):
+        self.expect(
+            "register read za",
+            substrs=[self.gen_za_value(svg, lambda r: r + value_offset)],
+        )
+
+    def check_disabled_za_register(self, svg):
+        self.expect("register read za", substrs=[self.gen_za_value(svg, lambda r: 0)])
+
+    def za_test_impl(self, enable_za):
+        if not self.isAArch64SME():
+            self.skipTest("SME must be present.")
+
+        self.build()
+        supported_vg = self.get_supported_vg()
+
+        self.runCmd("settings set target.run-args {}".format("1" if enable_za else "0"))
+
+        if not (2 in supported_vg and 4 in supported_vg):
+            self.skipTest("Not all required streaming vector lengths are supported.")
+
+        main_thread_stop_line = line_number("main.c", "// Break in main thread")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line)
+
+        thX_break_line1 = line_number("main.c", "// Thread X breakpoint 1")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", thX_break_line1)
+
+        thX_break_line2 = line_number("main.c", "// Thread X breakpoint 2")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", thX_break_line2)
+
+        thY_break_line1 = line_number("main.c", "// Thread Y breakpoint 1")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", thY_break_line1)
+
+        thY_break_line2 = line_number("main.c", "// Thread Y breakpoint 2")
+        lldbutil.run_break_set_by_file_and_line(self, "main.c", thY_break_line2)
+
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        if 8 in supported_vg:
+            if enable_za:
+                self.check_za_register(8, 1)
+            else:
+                self.check_disabled_za_register(8)
+        else:
+            if enable_za:
+                self.check_za_register(4, 1)
+            else:
+                self.check_disabled_za_register(4)
+
+        self.runCmd("process continue", RUN_SUCCEEDED)
+
+        process = self.dbg.GetSelectedTarget().GetProcess()
+        for idx in range(1, process.GetNumThreads()):
+            thread = process.GetThreadAtIndex(idx)
+            if thread.GetStopReason() != lldb.eStopReasonBreakpoint:
+                self.runCmd("thread continue %d" % (idx + 1))
+                self.assertEqual(thread.GetStopReason(), lldb.eStopReasonBreakpoint)
+
+            stopped_at_line_number = thread.GetFrameAtIndex(0).GetLineEntry().GetLine()
+
+            if stopped_at_line_number == thX_break_line1:
+                self.runCmd("thread select %d" % (idx + 1))
+                self.check_za_register(4, 2)
+                self.runCmd("register write vg 2")
+
+            elif stopped_at_line_number == thY_break_line1:
+                self.runCmd("thread select %d" % (idx + 1))
+                self.check_za_register(2, 3)
+                self.runCmd("register write vg 4")
+
+        self.runCmd("thread continue 2")
+        self.runCmd("thread continue 3")
+
+        for idx in range(1, process.GetNumThreads()):
+            thread = process.GetThreadAtIndex(idx)
+            self.assertEqual(thread.GetStopReason(), lldb.eStopReasonBreakpoint)
+
+            stopped_at_line_number = thread.GetFrameAtIndex(0).GetLineEntry().GetLine()
+
+            if stopped_at_line_number == thX_break_line2:
+                self.runCmd("thread select %d" % (idx + 1))
+                self.check_za_register(2, 2)
+
+            elif stopped_at_line_number == thY_break_line2:
+                self.runCmd("thread select %d" % (idx + 1))
+                self.check_za_register(4, 3)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_register_dynamic_config_main_enabled(self):
+        """Test multiple threads resizing ZA, with the main thread's ZA
+        enabled."""
+        self.za_test_impl(True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_register_dynamic_config_main_disabled(self):
+        """Test multiple threads resizing ZA, with the main thread's ZA
+        disabled."""
+        self.za_test_impl(False)

diff  --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c
new file mode 100644
index 000000000000000..fd2590dbe411f7f
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c
@@ -0,0 +1,104 @@
+#include <pthread.h>
+#include <stdatomic.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/prctl.h>
+
+// Important notes for this test:
+// * Making a syscall will disable streaming mode.
+// * LLDB writing to vg while in streaming mode will disable ZA
+//   (this is just how ptrace works).
+// * Using an instruction to write to an inactive ZA produces a SIGILL
+//   (doing the same thing via ptrace does not, as the kernel activates ZA for
+//   us in that case).
+
+#ifndef PR_SME_SET_VL
+#define PR_SME_SET_VL 63
+#endif
+
+#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
+#define SMSTART_SM SM_INST(3)
+#define SMSTART_ZA SM_INST(5)
+
+// Size of the staging buffer used to load one ZA row: 256 bytes (2048 bits),
+// the architectural maximum vector length. Defined at file scope because
+// preprocessor macros are not scoped to the function they appear in.
+#define MAX_VL_BYTES 256
+
+// Fill the first svl rows of ZA: every byte of row i is set to
+// (i + value_offset) (truncated to a byte by memset). ZA must already be
+// active; writing an inactive ZA with an instruction raises SIGILL (see the
+// notes at the top of this file).
+void set_za_register(int svl, int value_offset) {
+  uint8_t data[MAX_VL_BYTES];
+
+  // ldr za will actually wrap the selected vector row, by the number of rows
+  // you have. So setting one that didn't exist would actually set one that did.
+  // That's why we need the streaming vector length here.
+  for (int i = 0; i < svl; ++i) {
+    memset(data, i + value_offset, MAX_VL_BYTES);
+    // Each one of these loads a VL sized row of ZA. The asm reads `data`
+    // through a pointer operand, which the compiler cannot see, so the
+    // "memory" clobber is needed to keep the memset above from being
+    // treated as a dead store.
+    asm volatile("mov w12, %w0\n\t"
+                 "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
+                 "r"(&data)
+                 : "w12", "memory");
+  }
+}
+
+// These are used to make sure we only break in each thread once both of the
+// threads have been started. Otherwise when the test does "process continue"
+// it could stop in one thread and wait forever for the other one to start.
+atomic_bool threadX_ready = false;
+atomic_bool threadY_ready = false;
+
+void *threadX_func(void *x_arg) {
+  threadX_ready = true;
+  while (!threadY_ready) {
+  }
+
+  prctl(PR_SME_SET_VL, 8 * 4);
+  SMSTART_SM;
+  SMSTART_ZA;
+  set_za_register(8 * 4, 2);
+  SMSTART_ZA; // Thread X breakpoint 1
+  set_za_register(8 * 2, 2);
+  return NULL; // Thread X breakpoint 2
+}
+
+void *threadY_func(void *y_arg) {
+  threadY_ready = true;
+  while (!threadX_ready) {
+  }
+
+  prctl(PR_SME_SET_VL, 8 * 2);
+  SMSTART_SM;
+  SMSTART_ZA;
+  set_za_register(8 * 2, 3);
+  SMSTART_ZA; // Thread Y breakpoint 1
+  set_za_register(8 * 4, 3);
+  return NULL; // Thread Y breakpoint 2
+}
+
+int main(int argc, char *argv[]) {
+  // Expecting argument to tell us whether to enable ZA on the main thread.
+  if (argc != 2)
+    return 1;
+
+  prctl(PR_SME_SET_VL, 8 * 8);
+  SMSTART_SM;
+
+  if (argv[1][0] == '1') {
+    SMSTART_ZA;
+    set_za_register(8 * 8, 1);
+  }
+  // else we do not enable ZA and lldb will show 0s for it.
+
+  pthread_t x_thread;
+  if (pthread_create(&x_thread, NULL, threadX_func, 0)) // Break in main thread
+    return 1;
+
+  pthread_t y_thread;
+  if (pthread_create(&y_thread, NULL, threadY_func, 0))
+    return 1;
+
+  if (pthread_join(x_thread, NULL))
+    return 2;
+
+  if (pthread_join(y_thread, NULL))
+    return 2;
+
+  return 0;
+}

diff  --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile
new file mode 100644
index 000000000000000..f2ca08f3531aa16
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile
@@ -0,0 +1,5 @@
+C_SOURCES := main.c
+
+CFLAGS_EXTRAS := -march=armv8-a+sve+sme
+
+include Makefile.rules

diff  --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py
new file mode 100644
index 000000000000000..1d4bbd6207a51c1
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py
@@ -0,0 +1,252 @@
+"""
+Test the AArch64 SME ZA register is saved and restored around expressions.
+
+This attempts to cover expressions that change the following:
+* ZA enabled or not.
+* Streaming mode or not.
+* Streaming vector length (increasing and decreasing).
+* Some combintations of the above.
+"""
+
+from enum import IntEnum
+import lldb
+from lldbsuite.test.decorators import *
+from lldbsuite.test.lldbtest import *
+from lldbsuite.test import lldbutil
+
+
+# These enum values match the flag values used in the test program.
+class Mode(IntEnum):
+    SVE = 0
+    SSVE = 1
+
+
+class ZA(IntEnum):
+    Disabled = 0
+    Enabled = 1
+
+
+class AArch64ZATestCase(TestBase):
+    def get_supported_svg(self):
+        # Always build this probe program to start as streaming SVE.
+        # We will read/write "vg" here but since we are in streaming mode "svg"
+        # is really what we are writing ("svg" is a read only pseudo).
+        self.build()
+
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+        # Enter streaming mode, don't enable ZA, start_vl and other_vl don't
+        # matter here.
+        self.runCmd("settings set target.run-args 1 0 0 0")
+
+        stop_line = line_number("main.c", "// Set a break point here.")
+        lldbutil.run_break_set_by_file_and_line(
+            self, "main.c", stop_line, num_expected_locations=1
+        )
+
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread info 1",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint"],
+        )
+
+        # Write back the current vg to confirm read/write works at all.
+        current_svg = self.match("register read vg", ["(0x[0-9]+)"])
+        self.assertTrue(current_svg is not None)
+        self.expect("register write vg {}".format(current_svg.group()))
+
+        # Aka 128, 256 and 512 bit.
+        supported_svg = []
+        for svg in [2, 4, 8]:
+            # This could mask other errors but writing vg is tested elsewhere
+            # so we assume the hardware rejected the value.
+            self.runCmd("register write vg {}".format(svg), check=False)
+            if not self.res.GetError():
+                supported_svg.append(svg)
+
+        self.runCmd("breakpoint delete 1")
+        self.runCmd("continue")
+
+        return supported_svg
+
+    def read_vg(self):
+        process = self.dbg.GetSelectedTarget().GetProcess()
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sve_registers = registerSets.GetFirstValueByName(
+            "Scalable Vector Extension Registers"
+        )
+        return sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned()
+
+    def read_svg(self):
+        process = self.dbg.GetSelectedTarget().GetProcess()
+        registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters()
+        sve_registers = registerSets.GetFirstValueByName(
+            "Scalable Matrix Extension Registers"
+        )
+        return sve_registers.GetChildMemberWithName("svg").GetValueAsUnsigned()
+
+    def make_za_value(self, vl, generator):
+        # Generate a vector value string "{0x00 0x01....}".
+        rows = []
+        for row in range(vl):
+            byte = "0x{:02x}".format(generator(row))
+            rows.append(" ".join([byte] * vl))
+        return "{" + " ".join(rows) + "}"
+
+    def check_za(self, vl):
+        # We expect an increasing value starting at 1. Row 0=1, row 1 = 2, etc.
+        self.expect(
+            "register read za", substrs=[self.make_za_value(vl, lambda row: row + 1)]
+        )
+
+    def check_za_disabled(self, vl):
+        # When ZA is disabled, lldb will show ZA as all 0s.
+        self.expect("register read za", substrs=[self.make_za_value(vl, lambda row: 0)])
+
+    def za_expr_test_impl(self, sve_mode, za_state, swap_start_vl):
+        if not self.isAArch64SME():
+            self.skipTest("SME must be present.")
+
+        supported_svg = self.get_supported_svg()
+        if len(supported_svg) < 2:
+            self.skipTest("Target must support at least 2 streaming vector lengths.")
+
+        # vg is in units of 8 bytes.
+        start_vl = supported_svg[1] * 8
+        other_vl = supported_svg[2] * 8
+
+        if swap_start_vl:
+            start_vl, other_vl = other_vl, start_vl
+
+        self.line = line_number("main.c", "// Set a break point here.")
+
+        exe = self.getBuildArtifact("a.out")
+        self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET)
+        self.runCmd(
+            "settings set target.run-args {} {} {} {}".format(
+                sve_mode, za_state, start_vl, other_vl
+            )
+        )
+
+        lldbutil.run_break_set_by_file_and_line(
+            self, "main.c", self.line, num_expected_locations=1
+        )
+        self.runCmd("run", RUN_SUCCEEDED)
+
+        self.expect(
+            "thread backtrace",
+            STOPPED_DUE_TO_BREAKPOINT,
+            substrs=["stop reason = breakpoint 1."],
+        )
+
+        exprs = [
+            "expr_disable_za",
+            "expr_enable_za",
+            "expr_start_vl",
+            "expr_other_vl",
+            "expr_enable_sm",
+            "expr_disable_sm",
+        ]
+
+        # This may be the streaming or non-streaming vg. All that matters is
+        # that it is saved and restored, remaining constant throughout.
+        start_vg = self.read_vg()
+
+        # Check SVE registers to make sure that combination of scaling SVE
+        # and scaling ZA works properly. This is a brittle check, but failures
+        # are likely to be catastrophic when they do happen anyway.
+        sve_reg_names = "ffr {} {}".format(
+            " ".join(["z{}".format(n) for n in range(32)]),
+            " ".join(["p{}".format(n) for n in range(16)]),
+        )
+        self.runCmd("register read " + sve_reg_names)
+        sve_values = self.res.GetOutput()
+
+        def check_regs():
+            if za_state == ZA.Enabled:
+                self.check_za(start_vl)
+            else:
+                self.check_za_disabled(start_vl)
+
+            # svg and vg are in units of 8 bytes.
+            self.assertEqual(start_vl, self.read_svg() * 8)
+            self.assertEqual(start_vg, self.read_vg())
+
+            self.expect("register read " + sve_reg_names, substrs=[sve_values])
+
+        for expr in exprs:
+            expr_cmd = "expression {}()".format(expr)
+
+            # We do this twice because there were issues in development where
+            # using data stored by a previous WriteAllRegisterValues would crash
+            # the second time around.
+            self.runCmd(expr_cmd)
+            check_regs()
+            self.runCmd(expr_cmd)
+            check_regs()
+
+        # Run them in sequence to make sure there is no state lingering between
+        # them after a restore.
+        for expr in exprs:
+            self.runCmd("expression {}()".format(expr))
+            check_regs()
+
+        for expr in reversed(exprs):
+            self.runCmd("expression {}()".format(expr))
+            check_regs()
+
+    # These tests start with the 1st supported SVL and change to the 2nd
+    # supported SVL as needed.
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_enabled(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, False)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_disabled(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, False)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_enabled(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Enabled, False)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_disabled(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Disabled, False)
+
+    # These tests start in the 2nd supported SVL and change to the 1st supported
+    # SVL as needed.
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_enabled_
diff erent_vl(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_disabled_
diff erent_vl(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_enabled_
diff erent_vl(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Enabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_disabled_
diff erent_vl(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Disabled, True)

diff  --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c
new file mode 100644
index 000000000000000..a8434787a5a1235
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c
@@ -0,0 +1,225 @@
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/prctl.h>
+
+// Important details for this program:
+// * Making a syscall will disable streaming mode if it is active.
+// * Changing the vector length will make streaming mode and ZA inactive.
+// * ZA can be active independent of streaming mode.
+// * ZA's size is the streaming vector length squared.
+
+// Fallback values for the SME vector length prctl interface, used only when
+// <sys/prctl.h> does not already define them.
+#ifndef PR_SME_SET_VL
+#define PR_SME_SET_VL 63
+#endif
+
+#ifndef PR_SME_GET_VL
+#define PR_SME_GET_VL 64
+#endif
+
+#ifndef PR_SME_VL_LEN_MASK
+#define PR_SME_VL_LEN_MASK 0xffff
+#endif
+
+// SMSTART/SMSTOP are written as raw "msr s0_3_c4_cN_3, xzr" system register
+// encodings (presumably so an assembler without SME mnemonics can build this).
+// The CRm value N selects what changes, as named by the macros below:
+//   7 / 6: streaming mode and ZA both on / off
+//   3 / 2: streaming mode only on / off
+//   5 / 4: ZA only on / off
+#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
+#define SMSTART SM_INST(7)
+#define SMSTART_SM SM_INST(3)
+#define SMSTART_ZA SM_INST(5)
+#define SMSTOP SM_INST(6)
+#define SMSTOP_SM SM_INST(2)
+#define SMSTOP_ZA SM_INST(4)
+
+// Vector lengths taken from the command line in main() (the value given to
+// prctl(PR_SME_SET_VL)). They are globals so the expr_* functions, which the
+// test presumably calls from LLDB expressions, can read them.
+int start_vl = 0;
+int other_vl = 0;
+
+// Give FFR, every predicate register and every Z register a known pattern
+// before the breakpoint in main(), so the test can later check they were
+// restored.
+//   * FFR is set all-true (setffr).
+//   * Predicates cycle through ptrue .b/.h/.s/.d then pfalse, so p0, p5, p10
+//     and p15 are all-true at byte granularity.
+//   * zN gets every byte set to N+1 (zeroing CPY under one of those all-true
+//     predicates).
+// No clobbers are declared on these asm statements, presumably on purpose:
+// the values must still be in the registers after this function returns, and
+// declaring e.g. the callee-saved d8-d15 as clobbered could make the compiler
+// restore them and undo the writes. NOTE(review): this relies on the compiler
+// not emitting FP/SVE code between or after these statements.
+void write_sve_regs() {
+  // We assume the smefa64 feature is present, which allows ffr access
+  // in streaming mode.
+  asm volatile("setffr\n\t");
+  asm volatile("ptrue p0.b\n\t");
+  asm volatile("ptrue p1.h\n\t");
+  asm volatile("ptrue p2.s\n\t");
+  asm volatile("ptrue p3.d\n\t");
+  asm volatile("pfalse p4.b\n\t");
+  asm volatile("ptrue p5.b\n\t");
+  asm volatile("ptrue p6.h\n\t");
+  asm volatile("ptrue p7.s\n\t");
+  asm volatile("ptrue p8.d\n\t");
+  asm volatile("pfalse p9.b\n\t");
+  asm volatile("ptrue p10.b\n\t");
+  asm volatile("ptrue p11.h\n\t");
+  asm volatile("ptrue p12.s\n\t");
+  asm volatile("ptrue p13.d\n\t");
+  asm volatile("pfalse p14.b\n\t");
+  asm volatile("ptrue p15.b\n\t");
+
+  // The predicates above must be written first: these copies depend on
+  // p0/p5/p10/p15 being all-true.
+  asm volatile("cpy  z0.b, p0/z, #1\n\t");
+  asm volatile("cpy  z1.b, p5/z, #2\n\t");
+  asm volatile("cpy  z2.b, p10/z, #3\n\t");
+  asm volatile("cpy  z3.b, p15/z, #4\n\t");
+  asm volatile("cpy  z4.b, p0/z, #5\n\t");
+  asm volatile("cpy  z5.b, p5/z, #6\n\t");
+  asm volatile("cpy  z6.b, p10/z, #7\n\t");
+  asm volatile("cpy  z7.b, p15/z, #8\n\t");
+  asm volatile("cpy  z8.b, p0/z, #9\n\t");
+  asm volatile("cpy  z9.b, p5/z, #10\n\t");
+  asm volatile("cpy  z10.b, p10/z, #11\n\t");
+  asm volatile("cpy  z11.b, p15/z, #12\n\t");
+  asm volatile("cpy  z12.b, p0/z, #13\n\t");
+  asm volatile("cpy  z13.b, p5/z, #14\n\t");
+  asm volatile("cpy  z14.b, p10/z, #15\n\t");
+  asm volatile("cpy  z15.b, p15/z, #16\n\t");
+  asm volatile("cpy  z16.b, p0/z, #17\n\t");
+  asm volatile("cpy  z17.b, p5/z, #18\n\t");
+  asm volatile("cpy  z18.b, p10/z, #19\n\t");
+  asm volatile("cpy  z19.b, p15/z, #20\n\t");
+  asm volatile("cpy  z20.b, p0/z, #21\n\t");
+  asm volatile("cpy  z21.b, p5/z, #22\n\t");
+  asm volatile("cpy  z22.b, p10/z, #23\n\t");
+  asm volatile("cpy  z23.b, p15/z, #24\n\t");
+  asm volatile("cpy  z24.b, p0/z, #25\n\t");
+  asm volatile("cpy  z25.b, p5/z, #26\n\t");
+  asm volatile("cpy  z26.b, p10/z, #27\n\t");
+  asm volatile("cpy  z27.b, p15/z, #28\n\t");
+  asm volatile("cpy  z28.b, p0/z, #29\n\t");
+  asm volatile("cpy  z29.b, p5/z, #30\n\t");
+  asm volatile("cpy  z30.b, p10/z, #31\n\t");
+  asm volatile("cpy  z31.b, p15/z, #32\n\t");
+}
+
+// Write something different so we will know if we didn't restore them
+// correctly.
+//
+// Called from the expr_* helpers. It clears FFR (via an all-false p0), sets
+// only p4, p9 and p14 all-true, clears every other predicate, and sets every
+// byte of zN to N+2.
+// The Z registers are written with unpredicated DUP on purpose. Most
+// predicates are false here, and a zeroing CPY under a false predicate would
+// store 0 rather than the intended immediate.
+// Like write_sve_regs, this assumes FFR is accessible in streaming mode
+// (smefa64), and it declares no clobbers so the values stay in the registers.
+void write_sve_regs_expr() {
+  asm volatile("pfalse p0.b\n\t");
+  asm volatile("wrffr p0.b\n\t");
+  asm volatile("pfalse p1.b\n\t");
+  asm volatile("pfalse p2.b\n\t");
+  asm volatile("pfalse p3.b\n\t");
+  asm volatile("ptrue p4.b\n\t");
+  asm volatile("pfalse p5.b\n\t");
+  asm volatile("pfalse p6.b\n\t");
+  asm volatile("pfalse p7.b\n\t");
+  asm volatile("pfalse p8.b\n\t");
+  asm volatile("ptrue p9.b\n\t");
+  asm volatile("pfalse p10.b\n\t");
+  asm volatile("pfalse p11.b\n\t");
+  asm volatile("pfalse p12.b\n\t");
+  asm volatile("pfalse p13.b\n\t");
+  asm volatile("ptrue p14.b\n\t");
+  asm volatile("pfalse p15.b\n\t");
+
+  asm volatile("dup  z0.b, #2\n\t");
+  asm volatile("dup  z1.b, #3\n\t");
+  asm volatile("dup  z2.b, #4\n\t");
+  asm volatile("dup  z3.b, #5\n\t");
+  asm volatile("dup  z4.b, #6\n\t");
+  asm volatile("dup  z5.b, #7\n\t");
+  asm volatile("dup  z6.b, #8\n\t");
+  asm volatile("dup  z7.b, #9\n\t");
+  asm volatile("dup  z8.b, #10\n\t");
+  asm volatile("dup  z9.b, #11\n\t");
+  asm volatile("dup  z10.b, #12\n\t");
+  asm volatile("dup  z11.b, #13\n\t");
+  asm volatile("dup  z12.b, #14\n\t");
+  asm volatile("dup  z13.b, #15\n\t");
+  asm volatile("dup  z14.b, #16\n\t");
+  asm volatile("dup  z15.b, #17\n\t");
+  asm volatile("dup  z16.b, #18\n\t");
+  asm volatile("dup  z17.b, #19\n\t");
+  asm volatile("dup  z18.b, #20\n\t");
+  asm volatile("dup  z19.b, #21\n\t");
+  asm volatile("dup  z20.b, #22\n\t");
+  asm volatile("dup  z21.b, #23\n\t");
+  asm volatile("dup  z22.b, #24\n\t");
+  asm volatile("dup  z23.b, #25\n\t");
+  asm volatile("dup  z24.b, #26\n\t");
+  asm volatile("dup  z25.b, #27\n\t");
+  asm volatile("dup  z26.b, #28\n\t");
+  asm volatile("dup  z27.b, #29\n\t");
+  asm volatile("dup  z28.b, #30\n\t");
+  asm volatile("dup  z29.b, #31\n\t");
+  asm volatile("dup  z30.b, #32\n\t");
+  asm volatile("dup  z31.b, #33\n\t");
+}
+
+// Fill ZA one row at a time. Every byte of row i is set to
+// (i + value_offset), truncated to 8 bits by memset.
+//
+// svl is the streaming vector length in bytes. That is also the number of ZA
+// rows, since ZA's size is SVL squared. Callers in this file enable ZA
+// (SMSTART_ZA) before calling this.
+void set_za_register(int svl, int value_offset) {
+  // The largest architectural SVL is 2048 bits, so one row never needs more
+  // than 256 bytes of source data.
+  enum { MAX_VL_BYTES = 256 };
+  uint8_t data[MAX_VL_BYTES];
+
+  // ldr za will actually wrap the selected vector row, by the number of rows
+  // you have. So setting one that didn't exist would actually set one that did.
+  // That's why we need the streaming vector length here.
+  for (int i = 0; i < svl; ++i) {
+    memset(data, i + value_offset, sizeof data);
+    // Each one of these loads a VL sized row of ZA. The asm reads data
+    // through a pointer operand, which the compiler cannot see. The "memory"
+    // clobber stops it from treating the memset above as a dead store and
+    // dropping or moving it.
+    asm volatile("mov w12, %w0\n\t"
+                 "ldr za[w12, 0], [%1]\n\t"
+                 :
+                 : "r"(i), "r"(data)
+                 : "w12", "memory");
+  }
+}
+
+// Expression helper: turn ZA off, leaving streaming mode as it is, then
+// overwrite the SVE registers with the "expression" pattern.
+void expr_disable_za() {
+  SMSTOP_ZA;
+  write_sve_regs_expr();
+}
+
+// Expression helper: turn ZA on and fill it with value offset 2, then
+// overwrite the SVE registers with the "expression" pattern.
+// NOTE(review): this uses start_vl as the row count, i.e. it assumes the
+// current SVL is still start_vl when the expression runs.
+void expr_enable_za() {
+  SMSTART_ZA;
+  set_za_register(start_vl, 2);
+  write_sve_regs_expr();
+}
+
+// Expression helper: switch to start_vl, then re-enable ZA and fill it with
+// value offset 4, then overwrite the SVE registers.
+// The VL change makes streaming mode and ZA inactive, which is why ZA is
+// started again afterwards. Streaming mode is left off.
+void expr_start_vl() {
+  prctl(PR_SME_SET_VL, start_vl);
+  SMSTART_ZA;
+  set_za_register(start_vl, 4);
+  write_sve_regs_expr();
+}
+
+// Expression helper: switch to other_vl, then re-enable ZA and fill it with
+// value offset 5, then overwrite the SVE registers.
+// The VL change makes streaming mode and ZA inactive, which is why ZA is
+// started again afterwards. Streaming mode is left off.
+void expr_other_vl() {
+  prctl(PR_SME_SET_VL, other_vl);
+  SMSTART_ZA;
+  set_za_register(other_vl, 5);
+  write_sve_regs_expr();
+}
+
+// Expression helper: enter streaming mode, leaving ZA as it is, then
+// overwrite the SVE registers with the "expression" pattern.
+void expr_enable_sm() {
+  SMSTART_SM;
+  write_sve_regs_expr();
+}
+
+// Expression helper: leave streaming mode, leaving ZA as it is, then
+// overwrite the SVE registers with the "expression" pattern.
+void expr_disable_sm() {
+  SMSTOP_SM;
+  write_sve_regs_expr();
+}
+
+// Usage: main <ssve 0|1> <za 0|1> <start_vl> <other_vl>
+// Sets up the requested SME state, writes known register values and stops at
+// the marked line so the test can inspect and later compare them.
+// Returns 1 on bad arguments or if the starting VL cannot be applied exactly.
+int main(int argc, char *argv[]) {
+  // We expect to get:
+  // * whether to enable streaming mode
+  // * whether to enable ZA
+  // * what the starting VL should be
+  // * what the other VL should be
+  if (argc != 5)
+    return 1;
+
+  bool ssve = argv[1][0] == '1';
+  bool za = argv[2][0] == '1';
+  start_vl = atoi(argv[3]);
+  other_vl = atoi(argv[4]);
+
+  // This has to happen before SMSTART_SM, because making a syscall would take
+  // us out of streaming mode. On success prctl returns the resulting VL
+  // configuration, and the kernel may have chosen a different supported VL
+  // than the one requested. Everything below (set_za_register and the expr_*
+  // helpers) treats start_vl as the real SVL, so stop here if it is not.
+  int vl_config = prctl(PR_SME_SET_VL, start_vl);
+  if (vl_config < 0 || (vl_config & PR_SME_VL_LEN_MASK) != start_vl)
+    return 1;
+
+  if (ssve)
+    SMSTART_SM;
+
+  if (za) {
+    SMSTART_ZA;
+    set_za_register(start_vl, 1);
+  }
+
+  write_sve_regs();
+
+  return 0; // Set a break point here.
+}


        


More information about the lldb-commits mailing list