[Mlir-commits] [mlir] [mlir][ArmSME] Refine the `EnableArmStreaming` pass (PR #79432)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 25 02:09:50 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Updates the logic in the `EnableArmStreaming` pass so that:
* the streaming mode is always enabled whenever the pass is run -
the `onlyIfRequiredByOps` flag no longer applies to the streaming
mode and is effectively removed;
* a new flag, `enableZAConservatively`, controls whether ZA is enabled
unconditionally, based on the `zaMode` flag (the default), or
conditionally - only when SME ops are present.
In effect, this change restricts the previous conditional behaviour to
the "ZA array" only, rather than to both the "streaming mode" and the
"ZA array". This is required for cases where we do want to enable the
streaming mode even though there are no SME ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/79432.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+2-3)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+6-8)
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+12-9)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805b..5c111373ca0c3fe 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -101,10 +101,9 @@ def EnableArmStreaming
"The function uses ZA state. The ZA state may "
"be used for input and/or output.")
)}]>,
- Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+ Option<"enableZAConservatively", "enable-za-conservatively", "bool",
/*default=*/"false",
- "Only apply the selected streaming/ZA modes if the function "
- " contains ops that require them.">
+ "Enable ZA iff the function contains ops that require it.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 79a6caffb6ee0bf..65113621e9504b9 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -58,25 +58,23 @@ constexpr StringLiteral
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
- bool onlyIfRequiredByOps) {
+ bool enableZAConservatively) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
- this->onlyIfRequiredByOps = onlyIfRequiredByOps;
+ this->enableZAConservatively = enableZAConservatively;
}
void runOnOperation() override {
auto op = getOperation();
- if (onlyIfRequiredByOps) {
- bool foundTileOp = false;
+ bool enableZA = !enableZAConservatively;
+ if (enableZAConservatively) {
op.walk([&](Operation *op) {
if (llvm::isa<ArmSMETileOpInterface>(op)) {
- foundTileOp = true;
+ enableZA = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
- if (!foundTileOp)
- return;
}
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
@@ -91,7 +89,7 @@ struct EnableArmStreamingPass
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
// supporting this later.
- if (zaMode != ArmZaMode::Disabled)
+ if (enableZA && zaMode != ArmZaMode::Disabled)
op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
}
};
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6b58d8fdc41b0e0..fbee73cf616d1b0 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -2,7 +2,7 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming="za-mode=new-za enable-za-conservatively" -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -24,17 +24,20 @@ func.func @arm_streaming() { return }
// CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }
-// CHECK-LABEL: @requires_arm_streaming
+// CHECK-LABEL: @requires_arm_za
// CHECK-SAME: attributes {arm_streaming}
-// IF-REQUIRED: @requires_arm_streaming
-// IF-REQUIRED-SAME: attributes {arm_streaming}
-func.func @requires_arm_streaming() {
+// IF-REQUIRED: @requires_arm_za
+// IF-REQUIRED-SAME: attributes {arm_new_za, arm_streaming}
+func.func @requires_arm_za() {
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
-// CHECK-LABEL: @does_not_require_arm_streaming
+// CHECK-LABEL: @does_not_require_arm_za
// CHECK-SAME: attributes {arm_streaming}
-// IF-REQUIRED: @does_not_require_arm_streaming
-// IF-REQUIRED-NOT: arm_streaming
-func.func @does_not_require_arm_streaming() { return }
+// There are no SME Ops inside the function, so the request to enable ZA is
+// ignored.
+// IF-REQUIRED: @does_not_require_arm_za
+// IF-REQUIRED-SAME: attributes {arm_streaming}
+// IF-REQUIRED-NOT: arm_new_za
+func.func @does_not_require_arm_za() { return }
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 6e028d5fb83614d..78c386d301e9b16 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,7 +1,7 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za enable-za-conservatively" \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index c0c1f55d7ddd1ae..f0f492225b35a4c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -2,7 +2,7 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za enable-za-conservatively" \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
``````````
</details>
https://github.com/llvm/llvm-project/pull/79432
More information about the Mlir-commits
mailing list