[Mlir-commits] [mlir] [mlir][ArmSME] Refine the `EnableArmStreaming` pass (PR #79432)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 25 02:09:50 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Updates the logic in `EnableArmStreaming` pass so that:
  * the streaming mode is always enabled whenever the pass is run -
    the `onlyIfRequiredByOps` flag no longer applies to the streaming
    mode and is effectively removed;
  * a new flag, `enableZAConservatively`, controls whether ZA is enabled
    unconditionally based on the `zaMode` flag (default) or
    conditionally - only when SME ops are present.

This change restricts the previous conditional behaviour to the "ZA array"
only, rather than applying it to both the "streaming mode" and the "ZA array".
This is required for cases where we do want to enable the streaming mode even
though the function contains no SME ops.


---
Full diff: https://github.com/llvm/llvm-project/pull/79432.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+2-3) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+6-8) 
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+12-9) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805b..5c111373ca0c3fe 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -101,10 +101,9 @@ def EnableArmStreaming
                             "The function uses ZA state. The ZA state may "
                             "be used for input and/or output.")
            )}]>,
-    Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+    Option<"enableZAConservatively", "enable-za-conservatively", "bool",
            /*default=*/"false",
-           "Only apply the selected streaming/ZA modes if the function "
-           " contains ops that require them.">
+           "Enable ZA iff the function contains ops that require it.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 79a6caffb6ee0bf..65113621e9504b9 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -58,25 +58,23 @@ constexpr StringLiteral
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
-                         bool onlyIfRequiredByOps) {
+                         bool enableZAConservatively) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
-    this->onlyIfRequiredByOps = onlyIfRequiredByOps;
+    this->enableZAConservatively = enableZAConservatively;
   }
   void runOnOperation() override {
     auto op = getOperation();
 
-    if (onlyIfRequiredByOps) {
-      bool foundTileOp = false;
+    bool enableZA = !enableZAConservatively;
+    if (enableZAConservatively) {
       op.walk([&](Operation *op) {
         if (llvm::isa<ArmSMETileOpInterface>(op)) {
-          foundTileOp = true;
+          enableZA = true;
           return WalkResult::interrupt();
         }
         return WalkResult::advance();
       });
-      if (!foundTileOp)
-        return;
     }
 
     if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
@@ -91,7 +89,7 @@ struct EnableArmStreamingPass
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
-    if (zaMode != ArmZaMode::Disabled)
+    if (enableZA && zaMode != ArmZaMode::Disabled)
       op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6b58d8fdc41b0e0..fbee73cf616d1b0 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -2,7 +2,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming="za-mode=new-za enable-za-conservatively" -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -24,17 +24,20 @@ func.func @arm_streaming() { return }
 // CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
 func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }
 
-// CHECK-LABEL: @requires_arm_streaming
+// CHECK-LABEL: @requires_arm_za
 // CHECK-SAME: attributes {arm_streaming}
-// IF-REQUIRED: @requires_arm_streaming
-// IF-REQUIRED-SAME: attributes {arm_streaming}
-func.func @requires_arm_streaming() {
+// IF-REQUIRED: @requires_arm_za
+// IF-REQUIRED-SAME: attributes {arm_new_za, arm_streaming}
+func.func @requires_arm_za() {
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
-// CHECK-LABEL: @does_not_require_arm_streaming
+// CHECK-LABEL: @does_not_require_arm_za
 // CHECK-SAME: attributes {arm_streaming}
-// IF-REQUIRED: @does_not_require_arm_streaming
-// IF-REQUIRED-NOT: arm_streaming
-func.func @does_not_require_arm_streaming() { return }
+// There are no SME Ops inside the function, so the request to enable ZA is
+// ignored.
+// IF-REQUIRED: @does_not_require_arm_za
+// IF-REQUIRED-SAME: attributes {arm_streaming}
+// IF-REQUIRED-NOT: arm_new_za
+func.func @does_not_require_arm_za() { return }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 6e028d5fb83614d..78c386d301e9b16 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,7 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za enable-za-conservatively" \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index c0c1f55d7ddd1ae..f0f492225b35a4c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -2,7 +2,7 @@
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
 // DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za enable-za-conservatively" \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
 // DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \

``````````

</details>


https://github.com/llvm/llvm-project/pull/79432


More information about the Mlir-commits mailing list