[Mlir-commits] [mlir] [mlir][ArmSME] Make use of backend function attributes for enabling ZA storage (PR #71044)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Nov 10 04:32:09 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/71044

>From 7e9249346d916d4fb5882aa0d81f00a9c4c43f33 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 1 Nov 2023 15:57:18 +0000
Subject: [PATCH 1/6] [mlir][ArmSME] Make use of backend function attributes
 for enabling ZA storage

Previously, we were inserting za.enable/disable intrinsics for functions
with the "arm_za" attribute (at the MLIR level), rather than using the
backend attributes. This was done to avoid a dependency on the SME ABI
functions from compiler-rt (which have only recently been implemented).

Doing things this way did have correctness issues. For example, calling
a streaming-mode function from another streaming-mode function (both
with ZA enabled) would lead to ZA being disabled after returning to the
caller (where it should still be enabled). Fixing issues like this would
require re-doing, within MLIR, the ABI work already done in the backend.

Instead, this patch switches to use the "arm_new_za" (backend) attribute
for enabling ZA for an MLIR function. For the integration tests this
requires some way of linking the SME ABI functions. This has been done
by adding an mlir_arm_sme_runtime library, which includes the
implementation from compiler-rt and can then be linked via the
`-shared-libs` flag.

To build mlir_arm_sme_runtime, the target has to be AArch64, and
the host compiler must be able to assemble SME instructions (this is
supported in recent versions of clang). Note that the host being AArch64
is already assumed by the integration tests linking other runtime
libraries (e.g. mlir_c_runner_utils).
---
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   |  3 -
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   | 10 +++-
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  | 20 +++++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  1 +
 .../ArmSME/Transforms/EnableArmStreaming.cpp  | 25 ++++-----
 .../Transforms/LegalizeForLLVMExport.cpp      | 55 +------------------
 mlir/lib/ExecutionEngine/ArmSMEStub.cpp       | 10 ++++
 mlir/lib/ExecutionEngine/CMakeLists.txt       | 34 ++++++++++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  7 ++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  3 +
 mlir/test/CMakeLists.txt                      |  4 ++
 .../Dialect/ArmSME/enable-arm-streaming.mlir  |  6 +-
 mlir/test/Dialect/ArmSME/enable-arm-za.mlir   | 22 ++++----
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |  4 +-
 .../CPU/ArmSME/load-store-128-bit-tile.mlir   |  4 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |  4 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  4 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  4 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |  4 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |  4 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |  4 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |  4 +-
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 11 ----
 mlir/test/lit.cfg.py                          |  3 +
 24 files changed, 126 insertions(+), 124 deletions(-)
 create mode 100644 mlir/lib/ExecutionEngine/ArmSMEStub.cpp

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index c86a73812a5899c..bcf2466b13a739f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -161,7 +161,4 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
 def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
 
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
 #endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index ab5c179f2dd7790..95b016e87921a67 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -24,15 +24,19 @@ namespace arm_sme {
 // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
 // In a locally streaming function PSTATE.SM is kept internal and the callee
 // manages it on entry/exit.
-enum class ArmStreaming { Default = 0, Locally = 1 };
+enum class ArmStreamingMode { Default = 0, Locally = 1 };
+
+// TODO: Add other ZA modes.
+// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
+enum class ArmZaMode { Disabled = 0, New = 1 };
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
 
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass>
-createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
-                             const bool enableZA = false);
+createEnableArmStreamingPass(const ArmStreamingMode = ArmStreamingMode::Default,
+                             const ArmZaMode = ArmZaMode::Disabled);
 
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3fa1b43eb9e67e0..e24487adc8a5bce 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -22,19 +22,27 @@ def EnableArmStreaming
   }];
   let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
   let options = [
-    Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
-          /*default=*/"mlir::arm_sme::ArmStreaming::Default",
+    Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
+          /*default=*/"mlir::arm_sme::ArmStreamingMode::Default",
           "Select how streaming-mode is managed at the function-level.",
           [{::llvm::cl::values(
-                clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Default, "default",
 						   "Streaming mode is part of the function interface "
 						   "(ABI), caller manages PSTATE.SM on entry/exit."),
-                clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Locally, "locally",
 						   "Streaming mode is internal to the function, callee "
 						   "manages PSTATE.SM on entry/exit.")
           )}]>,
-    Option<"enableZA", "enable-za", "bool", /*default=*/"false",
-           "Enable ZA storage array.">,
+    Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
+           /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
+           "Select how ZA-storage is managed at the function-level.",
+           [{::llvm::cl::values(
+                clEnumValN(mlir::arm_sme::ArmZaMode::Disabled, "disabled",
+					 	   "ZA storage is not enabled."),
+                clEnumValN(mlir::arm_sme::ArmZaMode::New, "new",
+					 	   "The function has ZA state. The ZA state is created on entry "
+               "and destroyed on exit.")
+           )}]>
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ffb79a196db28ad..c4553e013b8b8bb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1387,6 +1387,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
     OptionalAttr<UnitAttr>:$arm_streaming,
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
+    OptionalAttr<UnitAttr>:$arm_new_za,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 1d3a090e861013b..1b59b6d907235b4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -51,26 +51,26 @@ using namespace mlir::arm_sme;
 
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
-static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kArmNewZAAttr[] = "arm_new_za";
 static constexpr char kEnableArmStreamingIgnoreAttr[] =
     "enable_arm_streaming_ignore";
 
 namespace {
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
-  EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
-    this->mode = mode;
-    this->enableZA = enableZA;
+  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+    this->streamingMode = streamingMode;
+    this->zaMode = zaMode;
   }
   void runOnOperation() override {
     if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
       return;
     StringRef attr;
-    switch (mode) {
-    case ArmStreaming::Default:
+    switch (streamingMode) {
+    case ArmStreamingMode::Default:
       attr = kArmStreamingAttr;
       break;
-    case ArmStreaming::Locally:
+    case ArmStreamingMode::Locally:
       attr = kArmLocallyStreamingAttr;
       break;
     }
@@ -80,14 +80,13 @@ struct EnableArmStreamingPass
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
-    if (enableZA)
-      getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
+    if (zaMode == ArmZaMode::New)
+      getOperation()->setAttr(kArmNewZAAttr, UnitAttr::get(&getContext()));
   }
 };
 } // namespace
 
-std::unique_ptr<Pass>
-mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
-                                            const bool enableZA) {
-  return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
+std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
+    const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
+  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index d1a54658a595bf3..6078b3f2c5e4708 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -21,33 +21,6 @@ using namespace mlir;
 using namespace mlir::arm_sme;
 
 namespace {
-/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
-/// ops to enable the ZA storage array.
-struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::FuncOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointToStart(&op.front());
-    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
-
-/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
-/// disable the ZA storage array.
-struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::ReturnOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(op);
-    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
 
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
@@ -678,39 +651,13 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
       arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
-      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
-
-  // Mark 'func.func' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and the first op in the
-  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
-  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
-    if (funcOp.isDeclaration())
-      return true;
-    auto firstOp = funcOp.getBody().front().begin();
-    return !funcOp->hasAttr("arm_za") ||
-           isa<arm_sme::aarch64_sme_za_enable>(firstOp);
-  });
-
-  // Mark 'func.return' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and there's a preceding
-  //      'arm_sme::aarch64_sme_za_disable' intrinsic.
-  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
-    bool hasDisableZA = false;
-    auto funcOp = returnOp->getParentOp();
-    funcOp->walk<WalkOrder::PreOrder>(
-        [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
-    return !funcOp->hasAttr("arm_za") || hasDisableZA;
-  });
 }
 
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
diff --git a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
new file mode 100644
index 000000000000000..af741ec716800ce
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
@@ -0,0 +1,10 @@
+
+#include "llvm/Support/Compiler.h"
+
+extern "C" {
+
+bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
+  // The ArmSME tests are run within an emulator so we assume SME is available.
+  return true;
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdc797763ae3a41..6da49b472177684 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
+  ArmSMEStub.cpp
   AsyncRuntime.cpp
   CRunnerUtils.cpp
   CudaRuntimeWrappers.cpp
@@ -177,6 +178,39 @@ if(LLVM_ENABLE_PIC)
     target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
   endif()
 
+  if (MLIR_RUN_ARM_SME_TESTS)
+    if (NOT DEFINED LLVM_MAIN_SRC_DIR)
+      message(FATAL_ERROR "LLVM_MAIN_SRC_DIR must be provided to build the ArmSME runtime.")
+    endif()
+
+    if (NOT DEFINED MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME)
+      # This should work on an AArch64 host with a recent version of clang.
+      file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/has_arm_sme_check.S
+      ".arch armv9-a+sme
+      .global main
+      .type   main, %function
+      main: smstart
+      .size   main, .-main")
+      try_compile(MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/has_arm_sme_check.S)
+    endif()
+
+    if (NOT MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME)
+      message(FATAL_ERROR "Host compiler must be able to assemble AArch64 SME instructions to build the ArmSME runtime.")
+    endif()
+
+    # FIXME: This is very far from ideal, but enabling compiler-rt in the main
+    # build requires building much more than we need, and does not expose
+    # individual targets (e.g. to build a standalone runtime).
+    add_mlir_library(mlir_arm_sme_runtime
+      SHARED
+      ArmSMEStub.cpp
+      ${LLVM_MAIN_SRC_DIR}/../compiler-rt/lib/builtins/aarch64/sme-abi-init.c
+      ${LLVM_MAIN_SRC_DIR}/../compiler-rt/lib/builtins/aarch64/sme-abi.S
+      EXCLUDE_FROM_LIBMLIR)
+
+    target_compile_definitions(mlir_arm_sme_runtime PRIVATE -DCOMPILER_RT_SHARED_LIB)
+  endif()
+
   if(MLIR_ENABLE_CUDA_RUNNER)
     # Configure CUDA support. Using check_language first allows us to give a
     # custom error message.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 1f13bdf44992ae5..07074c704e08be8 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1583,7 +1583,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
     // explicit attribute.
     // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
+        attrName == "aarch64_pstate_sm_body" ||
+        attrName == "aarch64_pstate_za_new" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1623,6 +1624,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+
+  if (func->hasFnAttribute("aarch64_pstate_za_new"))
+    funcOp.setArmNewZa(true);
+
   llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
   if (attr.isValid()) {
     MLIRContext *context = funcOp.getContext();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 388ae61958b78b9..911c7141e45d5f2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -890,6 +890,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (func.getArmNewZa())
+    llvmFunc->addFnAttr("aarch64_pstate_za_new");
+
   if (auto attr = func.getVscaleRange())
     llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
         getLLVMContext(), attr->getMinRange().getInt(),
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index d81f3c4b1e20c5a..03bc7eec580418a 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -139,6 +139,10 @@ if(MLIR_ENABLE_ROCM_RUNNER)
   list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
 endif()
 
+if(MLIR_RUN_ARM_SME_TESTS)
+  list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_runtime)
+endif()
+
 list(APPEND MLIR_TEST_DEPENDS MLIRUnitTests)
 
 if(LLVM_BUILD_EXAMPLES)
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index e7bbe8c0047687d..2ec6f4090dff0c2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
-// RUN: mlir-opt %s -enable-arm-streaming=mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
 // CHECK-LOCALLY-LABEL: @arm_streaming
 // CHECK-LOCALLY-SAME: attributes {arm_locally_streaming}
 // CHECK-ENABLE-ZA-LABEL: @arm_streaming
-// CHECK-ENABLE-ZA-SAME: attributes {arm_streaming, arm_za}
+// CHECK-ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
 func.func @arm_streaming() { return }
 
 // CHECK-LABEL: @not_arm_streaming
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index d415b19f6fa94cf..8631721ef61bc77 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,18 +1,16 @@
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
 // RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
 // CHECK-LABEL: @declaration
 func.func private @declaration()
 
-// CHECK-LABEL: @arm_za
-func.func @arm_za() {
-  // ENABLE-ZA: arm_sme.intr.za.enable
-  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
-  // ENABLE-ZA-NEXT: return
-  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
-  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
-  return
-}
+// ENABLE-ZA-LABEL: @arm_new_za
+// ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
+// DISABLE-ZA-LABEL: @arm_new_za
+// DISABLE-ZA-NOT: arm_new_za
+// DISABLE-ZA-SAME: attributes {arm_streaming}
+// NO-ARM-STREAMING-LABEL: @arm_new_za
+// NO-ARM-STREAMING-NOT: arm_new_za
+// NO-ARM-STREAMING-NOT: arm_streaming
+func.func @arm_new_za() { return }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 131cbc05a9857e0..8755a2c5064ee82 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -3,14 +3,14 @@
 // RUN:   -test-transform-dialect-erase-schedule \
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN:   -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
-// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime | \
 // RUN: FileCheck %s
 
 func.func @entry() {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index 1d6125a0d7999f5..276a095ecc3cad7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = test_load_store_zaq0
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index eda4d9a090f8d40..27af2f0a5daa618 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index ae5ad9cc2a5e90c..07a026e14b68b36 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile}
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 36ce896a4c1bd90..81646b17e12dc48 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile}
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 65b930115e88895..a88350c14ff2da2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 92031586b8cfc91..ad55c4a4b11aa75 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,11 +1,11 @@
-// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
 // RUN:  -e entry -entry-point-result=i32 \
-// RUN:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime | \
 // RUN: FileCheck %s
 
 // Integration test demonstrating filling a 32-bit element ZA tile with a
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index adf1d365cb99823..056ee4b954c4f94 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = za0_d_f64
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 455405d923bd664..a9a8060d4323083 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,12 +1,12 @@
 // DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 27c94d9aeac8bf4..aa0389e888b60d6 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -220,17 +220,6 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
 
 // -----
 
-// CHECK-LABEL: @arm_sme_toggle_za
-llvm.func @arm_sme_toggle_za() {
-  // CHECK: call void @llvm.aarch64.sme.za.enable()
-  "arm_sme.intr.za.enable"() : () -> ()
-  // CHECK: call void @llvm.aarch64.sme.za.disable()
-  "arm_sme.intr.za.disable"() : () -> ()
-  llvm.return
-}
-
-// -----
-
 // CHECK-LABEL: @arm_sme_vector_to_tile_horiz
 llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
                                         %nxv16i1 : vector<[16]xi1>,
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index da8488373862c36..0512e6bce3a3d94 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -126,6 +126,9 @@ def add_runtime(name):
 if config.enable_cuda_runner:
     tools.extend([add_runtime("mlir_cuda_runtime")])
 
+if config.mlir_run_arm_sme_tests:
+    tools.extend([add_runtime("mlir_arm_sme_runtime")])
+
 # The following tools are optional
 tools.extend(
     [

>From deda3731a0fdfc88e40758c834d2f32aabc92cda Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 6 Nov 2023 11:22:18 +0000
Subject: [PATCH 2/6] Default to ABI stubs configurable via
 ARM_SME_ABI_ROUTINES_SHLIB

This removes the direct dependency on compiler-rt and instead includes
ABI stub routines in MLIR. Our current tests pass with only stubs, as
we're not making nested ZA-enabled calls. Using these stubs can be
overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache
variable to a path to an alternate implementation.
---
 mlir/lib/ExecutionEngine/ArmSMEStub.cpp       | 38 +++++++++++++++++++
 mlir/lib/ExecutionEngine/CMakeLists.txt       | 35 ++---------------
 mlir/test/CMakeLists.txt                      |  6 ++-
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |  2 +-
 .../CPU/ArmSME/load-store-128-bit-tile.mlir   |  2 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |  2 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  2 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  2 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |  2 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |  2 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |  2 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |  2 +-
 mlir/test/lit.cfg.py                          | 21 +++++++---
 mlir/test/lit.site.cfg.py.in                  |  1 +
 14 files changed, 71 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
index af741ec716800ce..f9f64ad5e5ac81c 100644
--- a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
+++ b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
@@ -1,5 +1,21 @@
+//===- ArmSMEStub.cpp - ArmSME ABI routine stubs --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
 
 #include "llvm/Support/Compiler.h"
+#include <cstdint>
+#include <iostream>
+
+// The actual implementation of these routines is in:
+// compiler-rt/lib/builtins/aarch64/sme-abi.S. These stubs allow the current
+// ArmSME tests to run without depending on compiler-rt. This works as we don't
+// rely on nested ZA-enabled calls at the moment. The use of these stubs can be
+// overridden by setting the ARM_SME_ABI_ROUTINES_SHLIB CMake cache variable to
+// a path to an alternate implementation.
 
 extern "C" {
 
@@ -7,4 +23,26 @@ bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
   // The ArmSME tests are run within an emulator so we assume SME is available.
   return true;
 }
+
+struct sme_state {
+  int64_t x0;
+  int64_t x1;
+};
+
+sme_state LLVM_ATTRIBUTE_WEAK __arm_sme_state() {
+  std::cerr << "[warning] __arm_sme_state() stubbed!\n";
+  return sme_state{};
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_restore() {
+  std::cerr << "[warning] __arm_tpidr2_restore() stubbed!\n";
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_tpidr2_save() {
+  std::cerr << "[warning] __arm_tpidr2_save() stubbed!\n";
+}
+
+void LLVM_ATTRIBUTE_WEAK __arm_za_disable() {
+  std::cerr << "[warning] __arm_za_disable() stubbed!\n";
+}
 }
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index 6da49b472177684..aa8cb9728e8cdd0 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -178,38 +178,9 @@ if(LLVM_ENABLE_PIC)
     target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
   endif()
 
-  if (MLIR_RUN_ARM_SME_TESTS)
-    if (NOT DEFINED LLVM_MAIN_SRC_DIR)
-      message(FATAL_ERROR "LLVM_MAIN_SRC_DIR must be provided to build the ArmSME runtime.")
-    endif()
-
-    if (NOT DEFINED MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME)
-      # This should work on an AArch64 host with a recent version of clang.
-      file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/has_arm_sme_check.S
-      ".arch armv9-a+sme
-      .global main
-      .type   main, %function
-      main: smstart
-      .size   main, .-main")
-      try_compile(MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/has_arm_sme_check.S)
-    endif()
-
-    if (NOT MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME)
-      message(FATAL_ERROR "Host compiler must be able to assemble AArch64 SME instructions to build the ArmSME runtime.")
-    endif()
-
-    # FIXME: This is very far from ideal, but enabling compiler-rt in the main
-    # build requires building much more than we need, and does not expose
-    # individual targets (e.g. to build a standalone runtime).
-    add_mlir_library(mlir_arm_sme_runtime
-      SHARED
-      ArmSMEStub.cpp
-      ${LLVM_MAIN_SRC_DIR}/../compiler-rt/lib/builtins/aarch64/sme-abi-init.c
-      ${LLVM_MAIN_SRC_DIR}/../compiler-rt/lib/builtins/aarch64/sme-abi.S
-      EXCLUDE_FROM_LIBMLIR)
-
-    target_compile_definitions(mlir_arm_sme_runtime PRIVATE -DCOMPILER_RT_SHARED_LIB)
-  endif()
+  add_mlir_library(mlir_arm_sme_abi_stubs
+    SHARED
+    ArmSMEStub.cpp)
 
   if(MLIR_ENABLE_CUDA_RUNNER)
     # Configure CUDA support. Using check_language first allows us to give a
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 03bc7eec580418a..c7b7debffc56ab9 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -28,6 +28,8 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS)
       "If arch-specific Arm integration tests run emulated, find Arm native utility libraries in this directory.")
   set(MLIR_GPU_COMPILATION_TEST_FORMAT "fatbin" CACHE STRING
       "The GPU compilation format used by the tests.")
+  set(ARM_SME_ABI_ROUTINES_SHLIB "" CACHE STRING
+      "Path to a shared library containing Arm SME ABI routines, required for Arm SME integration tests.")
   option(MLIR_RUN_AMX_TESTS "Run AMX tests.")
   option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.")
   option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.")
@@ -139,8 +141,8 @@ if(MLIR_ENABLE_ROCM_RUNNER)
   list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
 endif()
 
-if(MLIR_RUN_ARM_SME_TESTS)
-  list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_runtime)
+if (MLIR_RUN_ARM_SME_TESTS)
+  list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_abi_stubs)
 endif()
 
 list(APPEND MLIR_TEST_DEPENDS MLIRUnitTests)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 8755a2c5064ee82..1d9f3977389c850 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -10,7 +10,7 @@
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
-// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime | \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
 // RUN: FileCheck %s
 
 func.func @entry() {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index 276a095ecc3cad7..ce2ef7e2b32e9c9 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -7,7 +7,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 27af2f0a5daa618..4eab68f5de9345c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -7,7 +7,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 07a026e14b68b36..5eba1054151a2ac 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -7,7 +7,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile}
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 81646b17e12dc48..a972dc1d2486ad6 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -7,7 +7,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile}
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index a88350c14ff2da2..5ca390768e5112a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -7,7 +7,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index ad55c4a4b11aa75..a4a3f0958f3c0f5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -5,7 +5,7 @@
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
 // RUN:  -e entry -entry-point-result=i32 \
-// RUN:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime | \
+// RUN:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
 // RUN: FileCheck %s
 
 // Integration test demonstrating filling a 32-bit element ZA tile with a
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 056ee4b954c4f94..46167bacf40892e 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -7,7 +7,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index a9a8060d4323083..4f1092232a0e1db 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -6,7 +6,7 @@
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 0512e6bce3a3d94..87bbe51e95d4c9d 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -54,10 +54,9 @@
 config.substitutions.append(("%host_cc", config.host_cc))
 
 
-# Searches for a runtime library with the given name and returns a tool
-# substitution of the same name and the found path.
+# Searches for a runtime library with the given name and returns the found path.
 # Correctly handles the platforms shared library directory and naming conventions.
-def add_runtime(name):
+def find_runtime(name):
     path = ""
     for prefix in ["", "lib"]:
         path = os.path.join(
@@ -65,7 +64,13 @@ def add_runtime(name):
         )
         if os.path.isfile(path):
             break
-    return ToolSubst(f"%{name}", path)
+    return path
+
+
+# Searches for a runtime library with the given name and returns a tool
+# substitution of the same name and the found path.
+def add_runtime(name):
+    return ToolSubst(f"%{name}", find_runtime(name))
 
 
 llvm_config.with_system_environment(["HOME", "INCLUDE", "LIB", "TMP", "TEMP"])
@@ -127,7 +132,13 @@ def add_runtime(name):
     tools.extend([add_runtime("mlir_cuda_runtime")])
 
 if config.mlir_run_arm_sme_tests:
-    tools.extend([add_runtime("mlir_arm_sme_runtime")])
+    config.substitutions.append(
+        (
+            "%arm_sme_abi_shlib",
+            # Use passed Arm SME ABI routines, if not present default to stubs.
+            config.arm_sme_abi_routines_shlib or find_runtime("mlir_arm_sme_abi_stubs"),
+        )
+    )
 
 # The following tools are optional
 tools.extend(
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index 2de40ba5e8e57e6..146e8443f5c98e5 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -56,6 +56,7 @@ config.arm_emulator_options = "@ARM_EMULATOR_OPTIONS@"
 config.arm_emulator_mlir_cpu_runner_executable = "@ARM_EMULATOR_MLIR_CPU_RUNNER_EXECUTABLE@"
 config.arm_emulator_lli_executable = "@ARM_EMULATOR_LLI_EXECUTABLE@"
 config.arm_emulator_utils_lib_dir = "@ARM_EMULATOR_UTILS_LIB_DIR@"
+config.arm_sme_abi_routines_shlib = "@ARM_SME_ABI_ROUTINES_SHLIB@"
 config.riscv_vector_emulator_executable = "@RISCV_VECTOR_EMULATOR_EXECUTABLE@"
 config.riscv_vector_emulator_options = "@RISCV_VECTOR_EMULATOR_OPTIONS@"
 config.riscv_emulator_lli_executable = "@RISCV_EMULATOR_LLI_EXECUTABLE@"

>From e4882a3c3ad177fc740e9fba752f0fec36d73353 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 6 Nov 2023 12:22:38 +0000
Subject: [PATCH 3/6] Fixup

---
 mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index e24487adc8a5bce..2ea5c6947754e65 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -38,7 +38,7 @@ def EnableArmStreaming
            "Select how ZA-storage is managed at the function-level.",
            [{::llvm::cl::values(
                 clEnumValN(mlir::arm_sme::ArmZaMode::Disabled, "disabled",
-					 	   "ZA storage is not enabled."),
+					 	   "ZA storage is disabled."),
                 clEnumValN(mlir::arm_sme::ArmZaMode::New, "new",
 					 	   "The function has ZA state. The ZA state is created on entry "
                "and destroyed on exit.")

>From f133e20478d8847469e5f82a317145db63ad9d8f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 8 Nov 2023 13:47:25 +0000
Subject: [PATCH 4/6] Fixups

- ArmSMEStub.cpp -> ArmSMEStubs.cpp
- Move enums to tablegen (to get generated stringification)
- Make enums closer to ACLE
  * Remove "Default" mode
  * Similar naming:
    - ArmZaMode::New -> ArmZaMode::NewZA
    - ArmStreamingMode::Locally -> ArmStreamingMode::StreamingLocally
---
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |  2 +
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   | 17 ++----
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  | 57 +++++++++++++++----
 .../ArmSME/Transforms/EnableArmStreaming.cpp  | 33 +++++------
 .../{ArmSMEStub.cpp => ArmSMEStubs.cpp}       |  0
 mlir/lib/ExecutionEngine/CMakeLists.txt       |  4 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 17 ++++--
 mlir/test/CMakeLists.txt                      |  2 +-
 .../Dialect/ArmSME/enable-arm-streaming.mlir  |  4 +-
 mlir/test/Dialect/ArmSME/enable-arm-za.mlir   |  2 +-
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |  2 +-
 .../CPU/ArmSME/load-store-128-bit-tile.mlir   |  2 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |  2 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  2 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  2 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |  2 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |  2 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |  2 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |  2 +-
 19 files changed, 91 insertions(+), 65 deletions(-)
 rename mlir/lib/ExecutionEngine/{ArmSMEStub.cpp => ArmSMEStubs.cpp} (100%)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2738b0fc404d63..38f48757b7749b7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSME)
+mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
+mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRArmSMETransformsIncGen)
 
 add_mlir_doc(Passes ArmSMEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 95b016e87921a67..6f7617f5411c57f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
 
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.h.inc"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -20,23 +21,13 @@ namespace arm_sme {
 //===----------------------------------------------------------------------===//
 // The EnableArmStreaming pass.
 //===----------------------------------------------------------------------===//
-// Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
-// the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
-// In a locally streaming function PSTATE.SM is kept internal and the callee
-// manages it on entry/exit.
-enum class ArmStreamingMode { Default = 0, Locally = 1 };
-
-// TODO: Add other ZA modes.
-// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
-enum class ArmZaMode { Disabled = 0, New = 1 };
-
 #define GEN_PASS_DECL
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
 
 /// Pass to enable Armv9 Streaming SVE mode.
-std::unique_ptr<Pass>
-createEnableArmStreamingPass(const ArmStreamingMode = ArmStreamingMode::Default,
-                             const ArmZaMode = ArmZaMode::Disabled);
+std::unique_ptr<Pass> createEnableArmStreamingPass(
+    const ArmStreamingMode = ArmStreamingMode::Streaming,
+    const ArmZaMode = ArmZaMode::Disabled);
 
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 2ea5c6947754e65..42aa397d160f5be 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -10,6 +10,34 @@
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
 
 include "mlir/Pass/PassBase.td"
+include "mlir/IR/EnumAttr.td"
+
+// Options for Armv9 Streaming SVE mode.
+def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode",
+    [
+      I32EnumAttrCase<"Disabled", 0, "disabled">,
+      // Streaming: Streaming-mode is part of the function interface (ABI).
+      I32EnumAttrCase<"Streaming", 1, "arm_streaming">,
+      // StreamingLocally: PSTATE.SM is kept internal and the callee manages it
+      // on entry/exit.
+      I32EnumAttrCase<"StreamingLocally", 2, "arm_locally_streaming">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+// Options for Armv9 ZA storage mode.
+// TODO: Add other ZA modes.
+// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
+def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
+    [
+      I32EnumAttrCase<"Disabled", 0, "disabled">,
+      // A function's ZA state is created on entry and destroyed on exit.
+      I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
 
 def EnableArmStreaming
     : Pass<"enable-arm-streaming", "mlir::func::FuncOp"> {
@@ -23,25 +51,30 @@ def EnableArmStreaming
   let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
   let options = [
     Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
-          /*default=*/"mlir::arm_sme::ArmStreamingMode::Default",
+          /*default=*/"mlir::arm_sme::ArmStreamingMode::Streaming",
           "Select how streaming-mode is managed at the function-level.",
           [{::llvm::cl::values(
-                clEnumValN(mlir::arm_sme::ArmStreamingMode::Default, "default",
-						   "Streaming mode is part of the function interface "
-						   "(ABI), caller manages PSTATE.SM on entry/exit."),
-                clEnumValN(mlir::arm_sme::ArmStreamingMode::Locally, "locally",
-						   "Streaming mode is internal to the function, callee "
-						   "manages PSTATE.SM on entry/exit.")
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Disabled,
+                           "disabled", "Streaming mode is disabled."),
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Streaming,
+                           "streaming",
+                           "Streaming mode is part of the function interface "
+                           "(ABI), caller manages PSTATE.SM on entry/exit."),
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingLocally,
+                           "streaming-locally",
+                           "Streaming mode is internal to the function, callee "
+                           "manages PSTATE.SM on entry/exit.")
           )}]>,
     Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
            /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
            "Select how ZA-storage is managed at the function-level.",
            [{::llvm::cl::values(
-                clEnumValN(mlir::arm_sme::ArmZaMode::Disabled, "disabled",
-					 	   "ZA storage is disabled."),
-                clEnumValN(mlir::arm_sme::ArmZaMode::New, "new",
-					 	   "The function has ZA state. The ZA state is created on entry "
-               "and destroyed on exit.")
+                 clEnumValN(mlir::arm_sme::ArmZaMode::Disabled,
+                            "disabled", "ZA storage is disabled."),
+                 clEnumValN(mlir::arm_sme::ArmZaMode::NewZA,
+                            "new-za",
+                            "The function has ZA state. The ZA state is "
+                            "created on entry and destroyed on exit.")
            )}]>
   ];
   let dependentDialects = ["func::FuncDialect"];
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 1b59b6d907235b4..0d5367ff30bb959 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -34,6 +34,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
@@ -48,14 +49,11 @@ namespace arm_sme {
 
 using namespace mlir;
 using namespace mlir::arm_sme;
+namespace {
 
-static constexpr char kArmStreamingAttr[] = "arm_streaming";
-static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
-static constexpr char kArmNewZAAttr[] = "arm_new_za";
-static constexpr char kEnableArmStreamingIgnoreAttr[] =
-    "enable_arm_streaming_ignore";
+constexpr StringLiteral
+    kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
 
-namespace {
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
@@ -63,25 +61,22 @@ struct EnableArmStreamingPass
     this->zaMode = zaMode;
   }
   void runOnOperation() override {
-    if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
+    auto op = getOperation();
+    if (op->getAttr(kEnableArmStreamingIgnoreAttr))
       return;
-    StringRef attr;
-    switch (streamingMode) {
-    case ArmStreamingMode::Default:
-      attr = kArmStreamingAttr;
-      break;
-    case ArmStreamingMode::Locally:
-      attr = kArmLocallyStreamingAttr;
-      break;
-    }
-    getOperation()->setAttr(attr, UnitAttr::get(&getContext()));
+    auto unitAttr = UnitAttr::get(&getContext());
+
+    if (streamingMode == ArmStreamingMode::Disabled)
+      return;
+
+    op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
 
     // The pass currently only supports enabling ZA when in streaming-mode, but
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
-    if (zaMode == ArmZaMode::New)
-      getOperation()->setAttr(kArmNewZAAttr, UnitAttr::get(&getContext()));
+    if (zaMode != ArmZaMode::Disabled)
+      op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
   }
 };
 } // namespace
diff --git a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp b/mlir/lib/ExecutionEngine/ArmSMEStubs.cpp
similarity index 100%
rename from mlir/lib/ExecutionEngine/ArmSMEStub.cpp
rename to mlir/lib/ExecutionEngine/ArmSMEStubs.cpp
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index aa8cb9728e8cdd0..fe139661f2bbb5a 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,7 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
-  ArmSMEStub.cpp
+  ArmSMEStubs.cpp
   AsyncRuntime.cpp
   CRunnerUtils.cpp
   CudaRuntimeWrappers.cpp
@@ -180,7 +180,7 @@ if(LLVM_ENABLE_PIC)
 
   add_mlir_library(mlir_arm_sme_abi_stubs
     SHARED
-    ArmSMEStub.cpp)
+    ArmSMEStubs.cpp)
 
   if(MLIR_ENABLE_CUDA_RUNNER)
     # Configure CUDA support. Using check_language first allows us to give a
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 07074c704e08be8..496c27e83e4d26c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1554,6 +1554,15 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
   funcOp.setMemoryAttr(memAttr);
 }
 
+// List of LLVM IR attributes that map to an explicit attribute on the MLIR
+// LLVMFuncOp.
+static constexpr std::array ExplicitAttributes{
+    StringLiteral("aarch64_pstate_sm_enabled"),
+    StringLiteral("aarch64_pstate_sm_body"),
+    StringLiteral("aarch64_pstate_za_new"),
+    StringLiteral("vscale_range"),
+};
+
 static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
   MLIRContext *context = funcOp.getContext();
   SmallVector<Attribute> passthroughs;
@@ -1579,12 +1588,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
       attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
     auto keyAttr = StringAttr::get(context, attrName);
 
-    // Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
-    // explicit attribute.
-    // Also skip the vscale_range, it is also an explicit attribute.
-    if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body" ||
-        attrName == "aarch64_pstate_za_new" || attrName == "vscale_range")
+    // Skip attributes that map to an explicit attribute on the LLVMFuncOp.
+    if (llvm::is_contained(ExplicitAttributes, attrName))
       continue;
 
     if (attr.isStringAttribute()) {
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index c7b7debffc56ab9..e4343095578c1f0 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -141,7 +141,7 @@ if(MLIR_ENABLE_ROCM_RUNNER)
   list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
 endif()
 
-if (MLIR_RUN_ARM_SME_TESTS)
+if (MLIR_RUN_ARM_SME_TESTS AND NOT ARM_SME_ABI_ROUTINES_SHLIB)
   list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_abi_stubs)
 endif()
 
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 2ec6f4090dff0c2..f221bff2e6f273f 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
-// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally  -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index 8631721ef61bc77..0f31278eefd1550 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
 // RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 1d9f3977389c850..efe4da7d3c50c6f 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -3,7 +3,7 @@
 // RUN:   -test-transform-dialect-erase-schedule \
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index ce2ef7e2b32e9c9..32e7e6b79ce09b9 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{entry_point} = test_load_store_zaq0
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 4eab68f5de9345c..44cf23f41b63254 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 5eba1054151a2ac..f1ecf768ebe83db 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index a972dc1d2486ad6..5c907bb1675e462 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 5ca390768e5112a..39b5ef2ade4b0c0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index a4a3f0958f3c0f5..baf2046722b9e0c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 46167bacf40892e..8878dca8bdcb6b1 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,6 +1,6 @@
 // DEFINE: %{entry_point} = za0_d_f64
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 4f1092232a0e1db..a890aaa6f309d15 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,5 +1,5 @@
 // DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=locally za-mode=new" \
+// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm

>From 1771013bc542f4d9b851353744e5f34c71f18fe4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 8 Nov 2023 13:57:30 +0000
Subject: [PATCH 5/6] Fix copy/paste error :)

---
 mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 42aa397d160f5be..3253b47e62abddb 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -12,8 +12,7 @@
 include "mlir/Pass/PassBase.td"
 include "mlir/IR/EnumAttr.td"
 
-// Options for Armv9 Streaming SVE mode.
-def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Tosa profile",
+def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode",
     [
       I32EnumAttrCase<"Disabled", 0, "disabled">,
       // Streaming: Streaming-mode is part of the function interface (ABI).
@@ -26,10 +25,9 @@ def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Tosa profile",
   let genSpecializedAttr = 0;
 }
 
-// Options for Armv9 ZA storage mode.
 // TODO: Add other ZA modes.
 // https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
-def ArmZaMode : I32EnumAttr<"ArmZaMode", "Tosa level",
+def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
     [
       I32EnumAttrCase<"Disabled", 0, "disabled">,
       // A function's ZA state is created on entry and destroyed on exit.

>From 84c9d7fcaff53ca08125d75f5f196c15d69ee751 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 10 Nov 2023 12:13:36 +0000
Subject: [PATCH 6/6] Update recently added tests

These tests required a few small changes to avoid nested ZA-enabled calls.
---
 .../Linalg/CPU/ArmSME/matmul-transpose-a.mlir    |  6 +++---
 .../Vector/CPU/ArmSME/test-transfer-read-2d.mlir | 16 +++++++++++-----
 .../CPU/ArmSME/test-transfer-write-2d.mlir       | 16 +++++++++++-----
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index 28179fed31eca4b..ab74f0100474263 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-vector-to-llvm=enable-arm-sme \
@@ -10,7 +10,7 @@
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
-// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
 // RUN: FileCheck %s
 
 func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
@@ -21,7 +21,7 @@ func.func @matmul_transpose_a(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : t
   return
 }
 
-func.func @main() {
+func.func @main() attributes { enable_arm_streaming_ignore } {
   %c0 = arith.constant 0 : i32
   %c7 = arith.constant 7 : index
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 48725d9ea03f94c..ccc08289570afc5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
@@ -134,7 +134,13 @@ func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
   return %A : memref<?x?xf32>
 }
 
-func.func @entry() {
+// This will be made a streaming function by enable-arm-streaming so return SVL.
+func.func @get_svl() -> index {
+  %vscale = vector.vscale
+  return %vscale : index
+}
+
+func.func @entry() attributes { enable_arm_streaming_ignore } {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -142,8 +148,8 @@ func.func @entry() {
 
   // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
   // non-zero offsets while remaining inbounds.
-  %vscale = vector.vscale
-  %svl_s = arith.muli %c4, %vscale : index
+  %svl = call @get_svl() : () -> index
+  %svl_s = arith.muli %c4, %svl : index
   %svl_s_plus_two = arith.addi %svl_s, %c2 : index
 
   %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 49c513badb7b071..f35f83dcec0daa2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
@@ -96,7 +96,13 @@ func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
   return %A : memref<?x?xf32>
 }
 
-func.func @entry() {
+// This will be made a streaming function by enable-arm-streaming so return SVL.
+func.func @get_svl() -> index {
+  %vscale = vector.vscale
+  return %vscale : index
+}
+
+func.func @entry() attributes { enable_arm_streaming_ignore } {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %c4 = arith.constant 4 : index
@@ -105,8 +111,8 @@ func.func @entry() {
   //
   // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
   // non-zero offsets while remaining inbounds.
-  %vscale = vector.vscale
-  %svl_s = arith.muli %c4, %vscale : index
+  %svl = call @get_svl() : () -> index
+  %svl_s = arith.muli %c4, %svl : index
   %svl_s_plus_two = arith.addi %svl_s, %c2 : index
   %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
 



More information about the Mlir-commits mailing list