[Mlir-commits] [mlir] [mlir][ArmSME] Make use of backend function attributes for enabling ZA storage (PR #71044)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 6 04:25:20 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

Previously, we were inserting za.enable/disable intrinsics for functions with the "arm_za" attribute (at the MLIR level), rather than using the backend attributes. This was done to avoid a dependency on the SME ABI functions from compiler-rt (which have only recently been implemented).

This approach had correctness issues. For example, calling a streaming-mode function from another streaming-mode function (both with ZA enabled) would leave ZA disabled after returning to the caller, where it should still be enabled. Fixing issues like this would require redoing, within MLIR, the ABI work already done in the backend.
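
As a rough illustration of the nested-call problem, the following toy C++ model (not MLIR or backend code) uses a single global flag to stand in for the ZA enable state. The names `zaEnabled`, `NaiveZAScope`, `callee`, and `callerSeesZAAfterCall` are invented for this sketch. It assumes only that the old lowering enabled ZA on entry and disabled it before every return.

```cpp
#include <cassert>

// Toy model only: one flag stands in for "ZA is enabled".
static bool zaEnabled = false;

// Hypothetical helper mimicking the old lowering: enable ZA on function
// entry and unconditionally disable it on function exit.
struct NaiveZAScope {
  NaiveZAScope() { zaEnabled = true; }
  ~NaiveZAScope() { zaEnabled = false; }
};

// A ZA-enabled function called from another ZA-enabled function.
void callee() {
  NaiveZAScope scope;
  assert(zaEnabled && "callee expects ZA to be usable in its body");
}

// Returns the ZA state the caller observes right after the nested call.
bool callerSeesZAAfterCall() {
  NaiveZAScope scope;
  callee();
  // The callee's exit has already cleared the flag, although the caller is
  // still inside its own ZA-enabled region.
  return zaEnabled;
}
```

In this model `callerSeesZAAfterCall()` returns `false`: the callee's exit clears the flag while the caller still expects ZA to be on. Handling this correctly needs the caller/callee coordination that the backend ABI lowering already implements.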

Instead, this patch switches to using the "arm_new_za" (backend) attribute to enable ZA for an MLIR function. For the integration tests, this requires some way of linking the SME ABI functions, which is provided by the `%arm_sme_abi_shlib` lit substitution. By default, it expands to a stub implementation of the SME ABI functions, but this can be overridden by setting the `ARM_SME_ABI_ROUTINES_SHLIB` CMake cache variable to the path of an alternative implementation. For now, the ArmSME integration tests pass with just the stubs, as we don't make use of nested ZA-enabled calls.
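
The option-to-attribute mapping after this patch can be sketched as a small standalone C++ function. This is a sketch under stated assumptions, not the pass itself: the two enums are redeclared locally so it compiles without MLIR headers, and `attrsFor` is a hypothetical helper that does not exist in the patch. The attribute names and the expected ordering are taken from the diff and its updated CHECK lines.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Local mirrors of the enums added in Passes.h (redeclared for this sketch).
enum class ArmStreamingMode { Default = 0, Locally = 1 };
enum class ArmZaMode { Disabled = 0, New = 1 };

// Hypothetical helper: lists the unit attributes the enable-arm-streaming
// pass attaches to a function for the given options. The result is sorted by
// name to match the order shown in the updated CHECK lines.
std::vector<std::string> attrsFor(ArmStreamingMode streamingMode,
                                  ArmZaMode zaMode) {
  std::vector<std::string> attrs;
  // streaming-mode=default -> arm_streaming
  // streaming-mode=locally -> arm_locally_streaming
  attrs.push_back(streamingMode == ArmStreamingMode::Locally
                      ? "arm_locally_streaming"
                      : "arm_streaming");
  // za-mode=new -> arm_new_za; za-mode=disabled adds nothing.
  if (zaMode == ArmZaMode::New)
    attrs.push_back("arm_new_za");
  std::sort(attrs.begin(), attrs.end());
  return attrs;
}
```

For example, `attrsFor(ArmStreamingMode::Default, ArmZaMode::New)` yields `arm_new_za` followed by `arm_streaming`, which corresponds to the `-enable-arm-streaming=za-mode=new` RUN line. When translating to LLVM IR, `arm_new_za` becomes the `aarch64_pstate_za_new` function attribute.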

A future patch may add an option to compiler-rt to build the SME builtins into a standalone shared library, making it easy to build and test against the actual implementation.

---

Patch is 31.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71044.diff


25 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (-3) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+7-3) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+14-6) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+1) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+12-13) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+1-54) 
- (added) mlir/lib/ExecutionEngine/ArmSMEStub.cpp (+48) 
- (modified) mlir/lib/ExecutionEngine/CMakeLists.txt (+5) 
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+6-1) 
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+3) 
- (modified) mlir/test/CMakeLists.txt (+6) 
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+3-3) 
- (modified) mlir/test/Dialect/ArmSME/enable-arm-za.mlir (+10-12) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir (+2-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (-11) 
- (modified) mlir/test/lit.cfg.py (+18-4) 
- (modified) mlir/test/lit.site.cfg.py.in (+1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e369ef203ad39d6..9f4ef24366b09db 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -131,7 +131,4 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
 def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
 
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
 #endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index ab5c179f2dd7790..95b016e87921a67 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -24,15 +24,19 @@ namespace arm_sme {
 // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
 // In a locally streaming function PSTATE.SM is kept internal and the callee
 // manages it on entry/exit.
-enum class ArmStreaming { Default = 0, Locally = 1 };
+enum class ArmStreamingMode { Default = 0, Locally = 1 };
+
+// TODO: Add other ZA modes.
+// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
+enum class ArmZaMode { Disabled = 0, New = 1 };
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
 
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass>
-createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
-                             const bool enableZA = false);
+createEnableArmStreamingPass(const ArmStreamingMode = ArmStreamingMode::Default,
+                             const ArmZaMode = ArmZaMode::Disabled);
 
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3fa1b43eb9e67e0..2ea5c6947754e65 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -22,19 +22,27 @@ def EnableArmStreaming
   }];
   let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
   let options = [
-    Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
-          /*default=*/"mlir::arm_sme::ArmStreaming::Default",
+    Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
+          /*default=*/"mlir::arm_sme::ArmStreamingMode::Default",
           "Select how streaming-mode is managed at the function-level.",
           [{::llvm::cl::values(
-                clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Default, "default",
 						   "Streaming mode is part of the function interface "
 						   "(ABI), caller manages PSTATE.SM on entry/exit."),
-                clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Locally, "locally",
 						   "Streaming mode is internal to the function, callee "
 						   "manages PSTATE.SM on entry/exit.")
           )}]>,
-    Option<"enableZA", "enable-za", "bool", /*default=*/"false",
-           "Enable ZA storage array.">,
+    Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
+           /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
+           "Select how ZA-storage is managed at the function-level.",
+           [{::llvm::cl::values(
+                clEnumValN(mlir::arm_sme::ArmZaMode::Disabled, "disabled",
+					 	   "ZA storage is disabled."),
+                clEnumValN(mlir::arm_sme::ArmZaMode::New, "new",
+					 	   "The function has ZA state. The ZA state is created on entry "
+               "and destroyed on exit.")
+           )}]>
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 638c31b39682ea6..dfc0588e92e44ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1415,6 +1415,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
     OptionalAttr<UnitAttr>:$arm_streaming,
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
+    OptionalAttr<UnitAttr>:$arm_new_za,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 1d3a090e861013b..1b59b6d907235b4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -51,26 +51,26 @@ using namespace mlir::arm_sme;
 
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
-static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kArmNewZAAttr[] = "arm_new_za";
 static constexpr char kEnableArmStreamingIgnoreAttr[] =
     "enable_arm_streaming_ignore";
 
 namespace {
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
-  EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
-    this->mode = mode;
-    this->enableZA = enableZA;
+  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+    this->streamingMode = streamingMode;
+    this->zaMode = zaMode;
   }
   void runOnOperation() override {
     if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
       return;
     StringRef attr;
-    switch (mode) {
-    case ArmStreaming::Default:
+    switch (streamingMode) {
+    case ArmStreamingMode::Default:
       attr = kArmStreamingAttr;
       break;
-    case ArmStreaming::Locally:
+    case ArmStreamingMode::Locally:
       attr = kArmLocallyStreamingAttr;
       break;
     }
@@ -80,14 +80,13 @@ struct EnableArmStreamingPass
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
-    if (enableZA)
-      getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
+    if (zaMode == ArmZaMode::New)
+      getOperation()->setAttr(kArmNewZAAttr, UnitAttr::get(&getContext()));
   }
 };
 } // namespace
 
-std::unique_ptr<Pass>
-mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
-                                            const bool enableZA) {
-  return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
+std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
+    const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
+  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index d1a54658a595bf3..6078b3f2c5e4708 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -21,33 +21,6 @@ using namespace mlir;
 using namespace mlir::arm_sme;
 
 namespace {
-/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
-/// ops to enable the ZA storage array.
-struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::FuncOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointToStart(&op.front());
-    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
-
-/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
-/// disable the ZA storage array.
-struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::ReturnOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(op);
-    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
 
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
@@ -678,39 +651,13 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
       arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
-      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
-
-  // Mark 'func.func' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and the first op in the
-  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
-  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
-    if (funcOp.isDeclaration())
-      return true;
-    auto firstOp = funcOp.getBody().front().begin();
-    return !funcOp->hasAttr("arm_za") ||
-           isa<arm_sme::aarch64_sme_za_enable>(firstOp);
-  });
-
-  // Mark 'func.return' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and there's a preceding
-  //      'arm_sme::aarch64_sme_za_disable' intrinsic.
-  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
-    bool hasDisableZA = false;
-    auto funcOp = returnOp->getParentOp();
-    funcOp->walk<WalkOrder::PreOrder>(
-        [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
-    return !funcOp->hasAttr("arm_za") || hasDisableZA;
-  });
 }
 
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
diff --git a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
new file mode 100644
index 000000000000000..f9f64ad5e5ac81c
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
@@ -0,0 +1,48 @@
+//===- ArmSMEStub.cpp - ArmSME ABI routine stubs --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/Compiler.h"
+#include <cstdint>
+#include <iostream>
+
+// The actual implementation of these routines is in:
+// compiler-rt/lib/builtins/aarch64/sme-abi.S. These stubs allow the current
+// ArmSME tests to run without depending on compiler-rt. This works as we don't
+// rely on nested ZA-enabled calls at the moment. The use of these stubs can be
+// overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache variable to
+// a path to an alternate implementation.
+
+extern "C" {
+
+bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
+  // The ArmSME tests are run within an emulator so we assume SME is available.
+  return true;
+}
+
+struct sme_state {
+  int64_t x0;
+  int64_t x1;
+};
+
+sme_state LLVM_ATTRIBUTE_WEAK __arm_sme_state() {
+  std::cerr << "[warning] __arm_sme_state() stubbed!\n";
+  return sme_state{};
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_restore() {
+  std::cerr << "[warning] __arm_tpidr2_restore() stubbed!\n";
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_save() {
+  std::cerr << "[warning] __arm_tpidr2_save() stubbed!\n";
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_za_disable() {
+  std::cerr << "[warning] __arm_za_disable() stubbed!\n";
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdc797763ae3a41..aa8cb9728e8cdd0 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
+  ArmSMEStub.cpp
   AsyncRuntime.cpp
   CRunnerUtils.cpp
   CudaRuntimeWrappers.cpp
@@ -177,6 +178,10 @@ if(LLVM_ENABLE_PIC)
     target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
   endif()
 
+  add_mlir_library(mlir_arm_sme_abi_stubs
+    SHARED
+    ArmSMEStub.cpp)
+
   if(MLIR_ENABLE_CUDA_RUNNER)
     # Configure CUDA support. Using check_language first allows us to give a
     # custom error message.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index e3562049cd81c76..b4c56f995234cb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1583,7 +1583,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
     // explicit attribute.
     // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
+        attrName == "aarch64_pstate_sm_body" ||
+        attrName == "aarch64_pstate_za_new" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1623,6 +1624,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+
+  if (func->hasFnAttribute("aarch64_pstate_za_new"))
+    funcOp.setArmNewZa(true);
+
   llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
   if (attr.isValid()) {
     MLIRContext *context = funcOp.getContext();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 7312388bc9b4dd2..e6247e12ecb38ac 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -890,6 +890,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (func.getArmNewZa())
+    llvmFunc->addFnAttr("aarch64_pstate_za_new");
+
   if (auto attr = func.getVscaleRange())
     llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
         getLLVMContext(), attr->getMinRange().getInt(),
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index d81f3c4b1e20c5a..c7b7debffc56ab9 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -28,6 +28,8 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
       "If arch-specific Arm integration tests run emulated, find Arm native utility libraries in this directory.")
   set(MLIR_GPU_COMPILATION_TEST_FORMAT "fatbin" CACHE STRING
       "The GPU compilation format used by the tests.")
+  set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
+      "Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
   option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
   option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
   option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
@@ -139,6 +141,10 @@ if(MLIR_ENABLE_ROCM_RUNNER)
   list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
 endif()
 
+if (MLIR_RUN_ARM_SME_TESTS)
+  list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_abi_stubs)
+endif()
+
 list(APPEND MLIR_TEST_DEPENDS MLIRUnitTests)
 
 if(LLVM_BUILD_EXAMPLES)
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index e7bbe8c0047687d..2ec6f4090dff0c2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
-// RUN: mlir-opt %s -enable-arm-streaming=mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
 // CHECK-LOCALLY-LABEL: @arm_streaming
 // CHECK-LOCALLY-SAME: attributes {arm_locally_streaming}
 // CHECK-ENABLE-ZA-LABEL: @arm_streaming
-// CHECK-ENABLE-ZA-SAME: attributes {arm_streaming, arm_za}
+// CHECK-ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
 func.func @arm_streaming() { return }
 
 // CHECK-LABEL: @not_arm_streaming
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index d415b19f6fa94cf..8631721ef61bc77 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,18 +1,16 @@
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
 // RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
 // CHECK-LABEL: @declaration
 func.func private @declaration()
 
-// CHECK-LABEL: @arm_za
-func.func @arm_za() {
-  // ENABLE-ZA: arm_sme.intr.za.enable
-  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
-  // ENABLE-ZA-NEXT: return
-  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
-  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
-  return
-}
+// ENABLE-ZA-LABEL: @arm_new_za
+// ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
+// DISABLE-ZA-LABEL: @arm_new_za
+// DISABLE-ZA-NOT: arm_new_za
+// DISABLE-ZA-SAME: attributes {arm_streaming}
+// NO-ARM-STREAMING-LABEL: @arm_new_za
+// NO-ARM-STREAMING-NOT: arm_new_za
+// NO-ARM-STREAMING-NOT: arm_streaming
+func.func @arm_new_za() { return }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 131cbc05a9857e0..1d9f3977389c850 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -3,14 +3,14 @@
 // RUN:   -test-transform-dialect-erase-schedule \
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN:   -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
-// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
 // RUN: FileCheck %s
 
 func.func @entry() {
diff --git a/mlir/...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71044

