[Mlir-commits] [mlir] [mlir][ArmSME] Add option to only enable streaming mode/ZA if required (PR #73931)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 30 04:42:27 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This adds an `only-if-required-by-ops` flag to the `enable-arm-streaming` pass. This flag defaults to `false` (which preserves the original behaviour); however, if it is set to `true`, the pass will only add the selected ZA/streaming mode to functions that contain ops implementing `ArmSMETileOpInterface`.
This simplifies enabling these modes: we can now first try lowering ops to ArmSME and then, only if that succeeds, add the relevant function attributes.
---
Full diff: https://github.com/llvm/llvm-project/pull/73931.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+1-1)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+5-1)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+22-3)
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+16)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 11a7385fe311dd3..21a97e9cbc794c9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,7 @@ namespace arm_sme {
/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
- const ArmZaMode = ArmZaMode::Disabled);
+ const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3253b47e62abddb..7b9c74e0b8f60e7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -73,7 +73,11 @@ def EnableArmStreaming
"new-za",
"The function has ZA state. The ZA state is "
"created on entry and destroyed on exit.")
- )}]>
+ )}]>,
+ Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+ /*default=*/"false",
+ "Only apply the selected streaming/ZA modes if the function "
+ " contains ops that require them.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index c3a1a1c9a3fb49e..79a6caffb6ee0bf 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -33,6 +33,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
@@ -56,12 +57,28 @@ constexpr StringLiteral
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
- EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+ EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
+ bool onlyIfRequiredByOps) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
+ this->onlyIfRequiredByOps = onlyIfRequiredByOps;
}
void runOnOperation() override {
auto op = getOperation();
+
+ if (onlyIfRequiredByOps) {
+ bool foundTileOp = false;
+ op.walk([&](Operation *op) {
+ if (llvm::isa<ArmSMETileOpInterface>(op)) {
+ foundTileOp = true;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ if (!foundTileOp)
+ return;
+ }
+
if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;
@@ -81,6 +98,8 @@ struct EnableArmStreamingPass
} // namespace
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
- const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
- return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
+ const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
+ bool onlyIfRequiredByOps) {
+ return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
+ onlyIfRequiredByOps);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 70119b08c3e91aa..b1188acbc0b2d74 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -17,3 +18,18 @@ func.func @arm_streaming() { return }
// CHECK-ENABLE-ZA-LABEL: @not_arm_streaming
// CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }
+
+// CHECK-LABEL: @requires_arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// IF-REQUIRED: @requires_arm_streaming
+// IF-REQUIRED-SAME: attributes {arm_streaming}
+func.func @requires_arm_streaming() {
+ %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+ return
+}
+
+// CHECK-LABEL: @does_not_require_arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// IF-REQUIRED: @does_not_require_arm_streaming
+// IF-REQUIRED-NOT: arm_streaming
+func.func @does_not_require_arm_streaming() { return }
``````````
</details>
https://github.com/llvm/llvm-project/pull/73931
More information about the Mlir-commits mailing list