[Mlir-commits] [mlir] [mlir][ArmSME] Add option to only enable streaming mode/ZA if required (PR #73931)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 04:42:27 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This adds an `only-if-required-by-ops` flag to the `enable-arm-streaming` pass. The flag defaults to `false`, which preserves the original behaviour; if it is set to `true`, the pass only adds the selected ZA/streaming mode to functions that contain ops implementing `ArmSMETileOpInterface`.

This simplifies enabling these modes, as we can now first try lowering ops to ArmSME and then, only if that succeeds, add the relevant function attributes.

---
Full diff: https://github.com/llvm/llvm-project/pull/73931.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+1-1) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+5-1) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+22-3) 
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 11a7385fe311dd3..21a97e9cbc794c9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,7 @@ namespace arm_sme {
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
-    const ArmZaMode = ArmZaMode::Disabled);
+    const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
 
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 3253b47e62abddb..7b9c74e0b8f60e7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -73,7 +73,11 @@ def EnableArmStreaming
                             "new-za",
                             "The function has ZA state. The ZA state is "
                             "created on entry and destroyed on exit.")
-           )}]>
+           )}]>,
+    Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+           /*default=*/"false",
+           "Only apply the selected streaming/ZA modes if the function "
+           "contains ops that require them.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index c3a1a1c9a3fb49e..79a6caffb6ee0bf 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -33,6 +33,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
 
@@ -56,12 +57,28 @@ constexpr StringLiteral
 
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
-  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode) {
+  EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
+                         bool onlyIfRequiredByOps) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
+    this->onlyIfRequiredByOps = onlyIfRequiredByOps;
   }
   void runOnOperation() override {
     auto op = getOperation();
+
+    if (onlyIfRequiredByOps) {
+      // Leave the function untouched unless some op in it implements
+      // ArmSMETileOpInterface; the walk stops at the first such op.
+      bool foundTileOp =
+          op.walk([](Operation *nestedOp) {
+              return llvm::isa<ArmSMETileOpInterface>(nestedOp)
+                         ? WalkResult::interrupt()
+                         : WalkResult::advance();
+            }).wasInterrupted();
+      if (!foundTileOp)
+        return;
+    }
+
     if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
         streamingMode == ArmStreamingMode::Disabled)
       return;
@@ -81,6 +98,8 @@ struct EnableArmStreamingPass
 } // namespace
 
 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
-    const ArmStreamingMode streamingMode, const ArmZaMode zaMode) {
-  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode);
+    const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
+    const bool onlyIfRequiredByOps) {
+  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
+                                                  onlyIfRequiredByOps);
 }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 70119b08c3e91aa..b1188acbc0b2d74 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=CHECK-IF-REQUIRED

 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -17,3 +18,18 @@ func.func @arm_streaming() { return }
 // CHECK-ENABLE-ZA-LABEL: @not_arm_streaming
 // CHECK-ENABLE-ZA-SAME: attributes {enable_arm_streaming_ignore}
 func.func @not_arm_streaming() attributes {enable_arm_streaming_ignore} { return }
+
+// CHECK-LABEL: @requires_arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// CHECK-IF-REQUIRED-LABEL: @requires_arm_streaming
+// CHECK-IF-REQUIRED-SAME: attributes {arm_streaming}
+func.func @requires_arm_streaming() {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+  return
+}
+
+// CHECK-LABEL: @does_not_require_arm_streaming
+// CHECK-SAME: attributes {arm_streaming}
+// CHECK-IF-REQUIRED-LABEL: @does_not_require_arm_streaming
+// CHECK-IF-REQUIRED-NOT: arm_streaming
+func.func @does_not_require_arm_streaming() { return }

``````````

</details>


https://github.com/llvm/llvm-project/pull/73931


More information about the Mlir-commits mailing list