[Mlir-commits] [mlir] [mlir][ArmSME] Make use of backend function attributes for enabling ZA storage (PR #71044)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Nov 2 03:47:53 PDT 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/71044

Previously, for functions with the "arm_za" attribute (at the MLIR level), we inserted za.enable/disable intrinsics at function entry and before each return, rather than using the backend attributes. This was done to avoid a dependency on the SME ABI functions from compiler-rt (which have only recently been implemented).

However, this approach had correctness issues. For example, calling a streaming-mode function from another streaming-mode function (both with ZA enabled) would leave ZA disabled after returning to the caller, where it should still be enabled, because the callee's za.disable also turns off ZA for the caller. Fixing issues like this within MLIR would require redoing ABI work that the backend already does.

Instead, this patch switches to using the "arm_new_za" attribute, which is translated to the backend's "aarch64_pstate_za_new" function attribute, to enable ZA for an MLIR function. For the integration tests, this requires some way to link the SME ABI functions. This is done by adding an mlir_arm_sme_runtime library, which builds the implementation from compiler-rt and can be linked via the `-shared-libs` flag.

To build mlir_arm_sme_runtime, the target must be AArch64 and the host compiler must be able to assemble SME instructions (recent versions of clang support this). Note that the integration tests already assume an AArch64 host, since they link other runtime libraries (e.g. mlir_c_runner_utils).

From a43cf44c4765cbc42d91073489d943484215f4b2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 1 Nov 2023 15:57:18 +0000
Subject: [PATCH] [mlir][ArmSME] Make use of backend function attributes for
 enabling ZA storage

Previously, for functions with the "arm_za" attribute (at the MLIR
level), we inserted za.enable/disable intrinsics at function entry and
before each return, rather than using the backend attributes. This was
done to avoid a dependency on the SME ABI functions from compiler-rt
(which have only recently been implemented).

However, this approach had correctness issues. For example, calling a
streaming-mode function from another streaming-mode function (both with
ZA enabled) would leave ZA disabled after returning to the caller, where
it should still be enabled, because the callee's za.disable also turns
off ZA for the caller. Fixing issues like this within MLIR would require
redoing ABI work that the backend already does.

Instead, this patch switches to using the "arm_new_za" attribute, which
is translated to the backend's "aarch64_pstate_za_new" function
attribute, to enable ZA for an MLIR function. For the integration tests,
this requires some way to link the SME ABI functions. This is done by
adding an mlir_arm_sme_runtime library, which builds the implementation
from compiler-rt and can be linked via the `-shared-libs` flag.

To build mlir_arm_sme_runtime, the target must be AArch64 and the host
compiler must be able to assemble SME instructions (recent versions of
clang support this). Note that the integration tests already assume an
AArch64 host, since they link other runtime libraries (e.g.
mlir_c_runner_utils).
---
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   |  3 -
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   | 10 +++-
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  | 20 +++++--
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |  1 +
 .../ArmSME/Transforms/EnableArmStreaming.cpp  | 25 ++++-----
 .../Transforms/LegalizeForLLVMExport.cpp      | 55 +------------------
 mlir/lib/ExecutionEngine/ArmSMEStub.cpp       | 10 ++++
 mlir/lib/ExecutionEngine/CMakeLists.txt       | 34 ++++++++++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  7 ++-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  3 +
 mlir/test/CMakeLists.txt                      |  4 ++
 .../Dialect/ArmSME/enable-arm-streaming.mlir  |  6 +-
 mlir/test/Dialect/ArmSME/enable-arm-za.mlir   | 22 ++++----
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |  4 +-
 .../CPU/ArmSME/load-store-128-bit-tile.mlir   |  4 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |  4 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  4 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  4 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |  4 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |  4 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |  4 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |  4 +-
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 11 ----
 mlir/test/lit.cfg.py                          |  3 +
 24 files changed, 126 insertions(+), 124 deletions(-)
 create mode 100644 mlir/lib/ExecutionEngine/ArmSMEStub.cpp

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index e369ef203ad39d6..9f4ef24366b09db 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -131,7 +131,4 @@ def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
 def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
 
-def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
-def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
-
 #endif // ARMSME_INTRINSIC_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index ab5c179f2dd7790..95b016e87921a67 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -24,15 +24,19 @@ namespace arm_sme {
 // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
 // In a locally streaming function PSTATE.SM is kept internal and the callee
 // manages it on entry/exit.
-enum class ArmStreaming { Default = 0, Locally = 1 };
+enum class ArmStreamingMode { Default = 0, Locally = 1 };
+
+// TODO: Add other ZA modes.
+// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
+enum class ArmZaMode { Disabled = 0, New = 1 };
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
 
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass>
-createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
-                             const bool enableZA = false);
+createEnableArmStreamingPass(const ArmStreamingMode = ArmStreamingMode::Default,
+                             const ArmZaMode = ArmZaMode::Disabled);
 
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3fa1b43eb9e67e0..e24487adc8a5bce 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -22,19 +22,27 @@ def EnableArmStreaming
   }];
   let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
   let options = [
-    Option<"mode", "mode", "mlir::arm_sme::ArmStreaming",
-          /*default=*/"mlir::arm_sme::ArmStreaming::Default",
+    Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
+          /*default=*/"mlir::arm_sme::ArmStreamingMode::Default",
           "Select how streaming-mode is managed at the function-level.",
           [{::llvm::cl::values(
-                clEnumValN(mlir::arm_sme::ArmStreaming::Default, "default",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Default, "default",
 						   "Streaming mode is part of the function interface "
 						   "(ABI), caller manages PSTATE.SM on entry/exit."),
-                clEnumValN(mlir::arm_sme::ArmStreaming::Locally, "locally",
+                clEnumValN(mlir::arm_sme::ArmStreamingMode::Locally, "locally",
 						   "Streaming mode is internal to the function, callee "
 						   "manages PSTATE.SM on entry/exit.")
           )}]>,
-    Option<"enableZA", "enable-za", "bool", /*default=*/"false",
-           "Enable ZA storage array.">,
+    Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
+           /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
+           "Select how ZA-storage is managed at the function-level.",
+           [{::llvm::cl::values(
+                clEnumValN(mlir::arm_sme::ArmZaMode::Disabled, "disabled",
+                           "ZA storage is not enabled."),
+                clEnumValN(mlir::arm_sme::ArmZaMode::New, "new",
+                           "The function has ZA state. The ZA state is "
+                           "created on entry and destroyed on exit.")
+           )}]>
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 638c31b39682ea6..dfc0588e92e44ed 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1415,6 +1415,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     DefaultValuedAttr<Visibility, "mlir::LLVM::Visibility::Default">:$visibility_,
     OptionalAttr<UnitAttr>:$arm_streaming,
     OptionalAttr<UnitAttr>:$arm_locally_streaming,
+    OptionalAttr<UnitAttr>:$arm_new_za,
     OptionalAttr<StrAttr>:$section,
     OptionalAttr<UnnamedAddr>:$unnamed_addr,
     OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 1d3a090e861013b..1b59b6d907235b4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -51,26 +51,26 @@ using namespace mlir::arm_sme;
 
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
-static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kArmNewZAAttr[] = "arm_new_za";
 static constexpr char kEnableArmStreamingIgnoreAttr[] =
     "enable_arm_streaming_ignore";
 
 namespace {
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
-  EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
-    this->mode = mode;
-    this->enableZA = enableZA;
+  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+    this->streamingMode = streamingMode;
+    this->zaMode = zaMode;
   }
   void runOnOperation() override {
     if (getOperation()->getAttr(kEnableArmStreamingIgnoreAttr))
       return;
     StringRef attr;
-    switch (mode) {
-    case ArmStreaming::Default:
+    switch (streamingMode) {
+    case ArmStreamingMode::Default:
       attr = kArmStreamingAttr;
       break;
-    case ArmStreaming::Locally:
+    case ArmStreamingMode::Locally:
       attr = kArmLocallyStreamingAttr;
       break;
     }
@@ -80,14 +80,13 @@ struct EnableArmStreamingPass
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
-    if (enableZA)
-      getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
+    if (zaMode == ArmZaMode::New)
+      getOperation()->setAttr(kArmNewZAAttr, UnitAttr::get(&getContext()));
   }
 };
 } // namespace
 
-std::unique_ptr<Pass>
-mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
-                                            const bool enableZA) {
-  return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
+std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
+    const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
+  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index d1a54658a595bf3..6078b3f2c5e4708 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -21,33 +21,6 @@ using namespace mlir;
 using namespace mlir::arm_sme;
 
 namespace {
-/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
-/// ops to enable the ZA storage array.
-struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::FuncOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointToStart(&op.front());
-    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
-
-/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
-/// disable the ZA storage array.
-struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(func::ReturnOp op,
-                                PatternRewriter &rewriter) const final {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(op);
-    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
-    rewriter.updateRootInPlace(op, [] {});
-    return success();
-  }
-};
 
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
@@ -678,39 +651,13 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
       arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
-      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
-
-  // Mark 'func.func' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and the first op in the
-  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
-  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
-    if (funcOp.isDeclaration())
-      return true;
-    auto firstOp = funcOp.getBody().front().begin();
-    return !funcOp->hasAttr("arm_za") ||
-           isa<arm_sme::aarch64_sme_za_enable>(firstOp);
-  });
-
-  // Mark 'func.return' ops as legal if either:
-  //   1. no 'arm_za' function attribute is present.
-  //   2. the 'arm_za' function attribute is present and there's a preceding
-  //      'arm_sme::aarch64_sme_za_disable' intrinsic.
-  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
-    bool hasDisableZA = false;
-    auto funcOp = returnOp->getParentOp();
-    funcOp->walk<WalkOrder::PreOrder>(
-        [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
-    return !funcOp->hasAttr("arm_za") || hasDisableZA;
-  });
 }
 
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<DisableZAPattern, EnableZAPattern>(patterns.getContext());
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
diff --git a/mlir/lib/ExecutionEngine/ArmSMEStub.cpp b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
new file mode 100644
index 000000000000000..af741ec716800ce
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/ArmSMEStub.cpp
@@ -0,0 +1,10 @@
+
+#include "llvm/Support/Compiler.h"
+
+extern "C" {
+
+bool LLVM_ATTRIBUTE_WEAK __aarch64_sme_accessible() {
+  // The ArmSME tests are run within an emulator so we assume SME is available.
+  return true;
+}
+}
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index fdc797763ae3a41..6da49b472177684 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -2,6 +2,7 @@
 # is a big dependency which most don't need.
 
 set(LLVM_OPTIONAL_SOURCES
+  ArmSMEStub.cpp
   AsyncRuntime.cpp
   CRunnerUtils.cpp
   CudaRuntimeWrappers.cpp
@@ -177,6 +178,39 @@ if(LLVM_ENABLE_PIC)
     target_link_options(mlir_async_runtime PRIVATE "-Wl,-exclude-libs,ALL")
   endif()
 
+  if (MLIR_RUN_ARM_SME_TESTS)
+    if (NOT DEFINED LLVM_MAIN_SRC_DIR)
+      message(FATAL_ERROR "LLVM_MAIN_SRC_DIR must be provided to build the ArmSME runtime.")
+    endif()
+
+    if (NOT DEFINED MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME)
+      # This should work on an AArch64 host with a recent version of clang.
+      file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/has_arm_sme_check.S
+      ".arch armv9-a+sme
+      .global main
+      .type   main, %function
+      main: smstart
+      .size   main, .-main")
+      try_compile(MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/has_arm_sme_check.S)
+    endif()
+
+    if (NOT MLIR_ARM_SME__CAN_ASSEMBLE_ARM_SME)
+      message(FATAL_ERROR "Host compiler must be able to assemble AArch64 SME instructions to build the ArmSME runtime.")
+    endif()
+
+    # FIXME: This is very far from ideal, but enabling compiler-rt in the main
+    # build requires building much more than we need, and does not expose
+    # individual targets (e.g. to build a standalone runtime).
+    add_mlir_library(mlir_arm_sme_runtime
+      SHARED
+      ArmSMEStub.cpp
+      ${LLVM_MAIN_SRC_DIR}/../compiler-rt/lib/builtins/aarch64/sme-abi-init.c
+      ${LLVM_MAIN_SRC_DIR}/../compiler-rt/lib/builtins/aarch64/sme-abi.S
+      EXCLUDE_FROM_LIBMLIR)
+
+    target_compile_definitions(mlir_arm_sme_runtime PRIVATE -DCOMPILER_RT_SHARED_LIB)
+  endif()
+
   if(MLIR_ENABLE_CUDA_RUNNER)
     # Configure CUDA support. Using check_language first allows us to give a
     # custom error message.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index e3562049cd81c76..b4c56f995234cb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1583,7 +1583,8 @@ static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
     // explicit attribute.
     // Also skip the vscale_range, it is also an explicit attribute.
     if (attrName == "aarch64_pstate_sm_enabled" ||
-        attrName == "aarch64_pstate_sm_body" || attrName == "vscale_range")
+        attrName == "aarch64_pstate_sm_body" ||
+        attrName == "aarch64_pstate_za_new" || attrName == "vscale_range")
       continue;
 
     if (attr.isStringAttribute()) {
@@ -1623,6 +1624,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
     funcOp.setArmStreaming(true);
   else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
     funcOp.setArmLocallyStreaming(true);
+
+  if (func->hasFnAttribute("aarch64_pstate_za_new"))
+    funcOp.setArmNewZa(true);
+
   llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
   if (attr.isValid()) {
     MLIRContext *context = funcOp.getContext();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 7312388bc9b4dd2..e6247e12ecb38ac 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -890,6 +890,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
   else if (func.getArmLocallyStreaming())
     llvmFunc->addFnAttr("aarch64_pstate_sm_body");
 
+  if (func.getArmNewZa())
+    llvmFunc->addFnAttr("aarch64_pstate_za_new");
+
   if (auto attr = func.getVscaleRange())
     llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
         getLLVMContext(), attr->getMinRange().getInt(),
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index d81f3c4b1e20c5a..03bc7eec580418a 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -139,6 +139,10 @@ if(MLIR_ENABLE_ROCM_RUNNER)
   list(APPEND MLIR_TEST_DEPENDS mlir_rocm_runtime)
 endif()
 
+if(MLIR_RUN_ARM_SME_TESTS)
+  list(APPEND MLIR_TEST_DEPENDS mlir_arm_sme_runtime)
+endif()
+
 list(APPEND MLIR_TEST_DEPENDS MLIRUnitTests)
 
 if(LLVM_BUILD_EXAMPLES)
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index e7bbe8c0047687d..2ec6f4090dff0c2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
-// RUN: mlir-opt %s -enable-arm-streaming=mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
 // CHECK-LOCALLY-LABEL: @arm_streaming
 // CHECK-LOCALLY-SAME: attributes {arm_locally_streaming}
 // CHECK-ENABLE-ZA-LABEL: @arm_streaming
-// CHECK-ENABLE-ZA-SAME: attributes {arm_streaming, arm_za}
+// CHECK-ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
 func.func @arm_streaming() { return }
 
 // CHECK-LABEL: @not_arm_streaming
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index d415b19f6fa94cf..8631721ef61bc77 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,18 +1,16 @@
-// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
 // RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
 
 // CHECK-LABEL: @declaration
 func.func private @declaration()
 
-// CHECK-LABEL: @arm_za
-func.func @arm_za() {
-  // ENABLE-ZA: arm_sme.intr.za.enable
-  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
-  // ENABLE-ZA-NEXT: return
-  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
-  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
-  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
-  return
-}
+// ENABLE-ZA-LABEL: @arm_new_za
+// ENABLE-ZA-SAME: attributes {arm_new_za, arm_streaming}
+// DISABLE-ZA-LABEL: @arm_new_za
+// DISABLE-ZA-NOT: arm_new_za
+// DISABLE-ZA-SAME: attributes {arm_streaming}
+// NO-ARM-STREAMING-LABEL: @arm_new_za
+// NO-ARM-STREAMING-NOT: arm_new_za
+// NO-ARM-STREAMING-NOT: arm_streaming
+func.func @arm_new_za() { return }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 131cbc05a9857e0..8755a2c5064ee82 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -3,14 +3,14 @@
 // RUN:   -test-transform-dialect-erase-schedule \
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN:   -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
-// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime | \
 // RUN: FileCheck %s
 
 func.func @entry() {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index 78f1bede5a6a529..ed4d1db275e043f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = test_load_store_zaq0
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index eda4d9a090f8d40..27af2f0a5daa618 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index ae5ad9cc2a5e90c..07a026e14b68b36 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile}
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 36ce896a4c1bd90..81646b17e12dc48 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile}
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index 65b930115e88895..a88350c14ff2da2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
-// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 92031586b8cfc91..ad55c4a4b11aa75 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,11 +1,11 @@
-// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
 // RUN:  -e entry -entry-point-result=i32 \
-// RUN:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \
+// RUN:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime | \
 // RUN: FileCheck %s
 
 // Integration test demonstrating filling a 32-bit element ZA tile with a
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index adf1d365cb99823..056ee4b954c4f94 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,13 +1,13 @@
 // DEFINE: %{entry_point} = za0_d_f64
 // DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 455405d923bd664..a9a8060d4323083 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,12 +1,12 @@
 // DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=locally za-mode=new" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" \
 // DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
-// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_sme_runtime
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 628d7ba4b649e51..ab9efafcfbf0f99 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -228,17 +228,6 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
 
 // -----
 
-// CHECK-LABEL: @arm_sme_toggle_za
-llvm.func @arm_sme_toggle_za() {
-  // CHECK: call void @llvm.aarch64.sme.za.enable()
-  "arm_sme.intr.za.enable"() : () -> ()
-  // CHECK: call void @llvm.aarch64.sme.za.disable()
-  "arm_sme.intr.za.disable"() : () -> ()
-  llvm.return
-}
-
-// -----
-
 // CHECK-LABEL: @arm_sme_vector_to_tile_horiz
 llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
                                         %nxv16i1 : vector<[16]xi1>,
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index da8488373862c36..0512e6bce3a3d94 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -126,6 +126,9 @@ def add_runtime(name):
 if config.enable_cuda_runner:
     tools.extend([add_runtime("mlir_cuda_runtime")])
 
+if config.mlir_run_arm_sme_tests:
+    tools.extend([add_runtime("mlir_arm_sme_runtime")])
+
 # The following tools are optional
 tools.extend(
     [


