[Mlir-commits] [mlir] [mlir][ArmSME] Add option to only enable streaming mode for scalable code (PR #94759)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Jun 7 07:31:27 PDT 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/94759

This adds a new option `-enable-arm-streaming=if-contains-scalable-vectors`, which only applies the selected streaming/ZA modes if the function contains scalable vector types.

As an NFC change, this patch also renames the `only-if-required-by-ops` option to `if-required-by-ops` (dropping the `only-` prefix).

>From 5d078454cad9d4b5b85f9479aeec8effab2c9cb0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 7 Jun 2024 14:07:36 +0000
Subject: [PATCH] [mlir][ArmSME] Add option to only enable streaming mode for
 scalable code

This adds a new option `-enable-arm-streaming=if-contains-scalable-vectors`,
which only applies the selected streaming/ZA modes if the function
contains scalable vector types.

As an NFC change, this patch also renames the `only-if-required-by-ops`
option to `if-required-by-ops` (dropping the `only-` prefix).
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |  3 +-
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  | 10 ++--
 .../ArmSME/Transforms/EnableArmStreaming.cpp  | 49 ++++++++++++++-----
 .../ArmSME/enable-arm-streaming-invalid.mlir  |  4 ++
 .../Dialect/ArmSME/enable-arm-streaming.mlir  | 17 ++++++-
 .../ArmSME/multi-tile-matmul-mixed-types.mlir |  2 +-
 .../lib/Dialect/ArmSME/TestLowerToArmSME.cpp  |  2 +-
 7 files changed, 69 insertions(+), 18 deletions(-)
 create mode 100644 mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 156744ba57e7b..167e5b787d1af 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,8 @@ namespace arm_sme {
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
-    const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
+    const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
+    bool ifContainsScalableVectors = false);
 
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 869a031d6cae8..8aba121432bba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -116,10 +116,14 @@ def EnableArmStreaming
                             "not be used for input and/or output and the "
                             "function must return with ZA unchanged")
            )}]>,
-    Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+    Option<"ifRequiredByOps", "if-required-by-ops", "bool",
            /*default=*/"false",
-           "Only apply the selected streaming/ZA modes if the function "
-           " contains ops that require them.">
+           "Apply the selected streaming/ZA modes if the function contains ops "
+           "that require them.">,
+    Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+           "bool", /*default=*/"false",
+           "Apply the selected streaming/ZA modes if the function contains "
+           "operations that use scalable vector types.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 79a6caffb6ee0..fb4bb41d87488 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -58,17 +58,25 @@ constexpr StringLiteral
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
-                         bool onlyIfRequiredByOps) {
+                         bool ifRequiredByOps, bool ifContainsScalableVectors) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
-    this->onlyIfRequiredByOps = onlyIfRequiredByOps;
+    this->ifRequiredByOps = ifRequiredByOps;
+    this->ifContainsScalableVectors = ifContainsScalableVectors;
   }
   void runOnOperation() override {
-    auto op = getOperation();
+    auto function = getOperation();
 
-    if (onlyIfRequiredByOps) {
+    if (ifRequiredByOps && ifContainsScalableVectors) {
+      function->emitOpError(
+          "enable-arm-streaming: `if-required-by-ops` and "
+          "`if-contains-scalable-vectors` are mutually exclusive");
+      return signalPassFailure();
+    }
+
+    if (ifRequiredByOps) {
       bool foundTileOp = false;
-      op.walk([&](Operation *op) {
+      function.walk([&](Operation *op) {
         if (llvm::isa<ArmSMETileOpInterface>(op)) {
           foundTileOp = true;
           return WalkResult::interrupt();
@@ -79,27 +87,46 @@ struct EnableArmStreamingPass
         return;
     }
 
-    if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
+    if (ifContainsScalableVectors) {
+      bool foundScalableVector = false;
+      auto isScalableVector = [&](Type type) {
+        if (auto vectorType = dyn_cast<VectorType>(type))
+          return vectorType.isScalable();
+        return false;
+      };
+      function.walk([&](Operation *op) {
+        if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+            llvm::any_of(op->getResultTypes(), isScalableVector)) {
+          foundScalableVector = true;
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+      if (!foundScalableVector)
+        return;
+    }
+
+    if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
         streamingMode == ArmStreamingMode::Disabled)
       return;
 
     auto unitAttr = UnitAttr::get(&getContext());
 
-    op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
+    function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
 
     // The pass currently only supports enabling ZA when in streaming-mode, but
     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
     // supporting this later.
     if (zaMode != ArmZaMode::Disabled)
-      op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
+      function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
   }
 };
 } // namespace
 
 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
     const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
-    bool onlyIfRequiredByOps) {
-  return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
-                                                  onlyIfRequiredByOps);
+    bool ifRequiredByOps, bool ifContainsScalableVectors) {
+  return std::make_unique<EnableArmStreamingPass>(
+      streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
 }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
new file mode 100644
index 0000000000000..da70b632d70c4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+
+// expected-error @below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6b58d8fdc41b0..2011802c5c8b2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -2,7 +2,8 @@
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
 // IF-REQUIRED: @does_not_require_arm_streaming
 // IF-REQUIRED-NOT: arm_streaming
 func.func @does_not_require_arm_streaming() { return }
+
+// IF-SCALABLE-LABEL: @contains_scalable_vectors
+// IF-SCALABLE-SAME: attributes {arm_streaming}
+func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
+  %0 = arith.addf %vec, %vec : vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @no_scalable_vectors
+// IF-SCALABLE-NOT: arm_streaming
+func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
+  %0 = arith.addf %vec, %vec : vector<4xf32>
+  return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index 10ffed2688178..aabd9d2ce788e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -4,7 +4,7 @@
 // RUN:   -arm-sme-vector-legalization -canonicalize -cse \
 // RUN:   -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
 // RUN:   -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
 // RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
 // RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index d3dabaf200fdc..a220791969d53 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
   // Enable streaming-mode and ZA.
   pm.addPass(arm_sme::createEnableArmStreamingPass(
       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
-      /*onlyIfRequiredByOps=*/true));
+      /*ifRequiredByOps=*/true));
 
   // Convert SCF to CF (required for ArmSME tile allocation).
   pm.addPass(createConvertSCFToCFPass());



More information about the Mlir-commits mailing list