[Mlir-commits] [mlir] [mlir][ArmSME] Add option to only enable streaming mode for scalable code (PR #94759)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 7 07:31:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This adds a new option `-enable-arm-streaming=if-contains-scalable-vectors`, which only applies the selected streaming/ZA modes if the function contains scalable vector types.
As an NFC change, this patch also drops the `only-` prefix from the `only-if-required-by-ops` option, renaming it to `if-required-by-ops`.
---
Full diff: https://github.com/llvm/llvm-project/pull/94759.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+2-1)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+7-3)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+38-11)
- (added) mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir (+4)
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+16-1)
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir (+1-1)
- (modified) mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 156744ba57e7b..167e5b787d1af 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -27,7 +27,8 @@ namespace arm_sme {
/// Pass to enable Armv9 Streaming SVE mode.
std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
- const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
+ const ArmZaMode = ArmZaMode::Disabled, bool ifRequiredByOps = false,
+ bool ifContainsScalableVectors = false);
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 869a031d6cae8..8aba121432bba 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -116,10 +116,14 @@ def EnableArmStreaming
"not be used for input and/or output and the "
"function must return with ZA unchanged")
)}]>,
- Option<"onlyIfRequiredByOps", "only-if-required-by-ops", "bool",
+ Option<"ifRequiredByOps", "if-required-by-ops", "bool",
/*default=*/"false",
- "Only apply the selected streaming/ZA modes if the function "
- " contains ops that require them.">
+ "Apply the selected streaming/ZA modes if the function contains ops "
+ "that require them.">,
+ Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+ "bool", /*default=*/"false",
+ "Apply the selected streaming/ZA modes if the function contains "
+ "operations that use scalable vector types.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 79a6caffb6ee0..fb4bb41d87488 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -58,17 +58,25 @@ constexpr StringLiteral
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
- bool onlyIfRequiredByOps) {
+ bool ifRequiredByOps, bool ifContainsScalableVectors) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
- this->onlyIfRequiredByOps = onlyIfRequiredByOps;
+ this->ifRequiredByOps = ifRequiredByOps;
+ this->ifContainsScalableVectors = ifContainsScalableVectors;
}
void runOnOperation() override {
- auto op = getOperation();
+ auto function = getOperation();
- if (onlyIfRequiredByOps) {
+ if (ifRequiredByOps && ifContainsScalableVectors) {
+ function->emitOpError(
+ "enable-arm-streaming: `if-required-by-ops` and "
+ "`if-contains-scalable-vectors` are mutually exclusive");
+ return signalPassFailure();
+ }
+
+ if (ifRequiredByOps) {
bool foundTileOp = false;
- op.walk([&](Operation *op) {
+ function.walk([&](Operation *op) {
if (llvm::isa<ArmSMETileOpInterface>(op)) {
foundTileOp = true;
return WalkResult::interrupt();
@@ -79,27 +87,46 @@ struct EnableArmStreamingPass
return;
}
- if (op->getAttr(kEnableArmStreamingIgnoreAttr) ||
+ if (ifContainsScalableVectors) {
+ bool foundScalableVector = false;
+ auto isScalableVector = [&](Type type) {
+ if (auto vectorType = dyn_cast<VectorType>(type))
+ return vectorType.isScalable();
+ return false;
+ };
+ function.walk([&](Operation *op) {
+ if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+ llvm::any_of(op->getResultTypes(), isScalableVector)) {
+ foundScalableVector = true;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ if (!foundScalableVector)
+ return;
+ }
+
+ if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
streamingMode == ArmStreamingMode::Disabled)
return;
auto unitAttr = UnitAttr::get(&getContext());
- op->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
+ function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
// The pass currently only supports enabling ZA when in streaming-mode, but
// ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
// streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
// supporting this later.
if (zaMode != ArmZaMode::Disabled)
- op->setAttr(stringifyArmZaMode(zaMode), unitAttr);
+ function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
}
};
} // namespace
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
- bool onlyIfRequiredByOps) {
- return std::make_unique<EnableArmStreamingPass>(streamingMode, zaMode,
- onlyIfRequiredByOps);
+ bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ return std::make_unique<EnableArmStreamingPass>(
+ streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
new file mode 100644
index 0000000000000..da70b632d70c4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+
+// expected-error @below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6b58d8fdc41b0..2011802c5c8b2 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -2,7 +2,8 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=only-if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
+// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -38,3 +39,17 @@ func.func @requires_arm_streaming() {
// IF-REQUIRED: @does_not_require_arm_streaming
// IF-REQUIRED-NOT: arm_streaming
func.func @does_not_require_arm_streaming() { return }
+
+// IF-SCALABLE-LABEL: @contains_scalable_vectors
+// IF-SCALABLE-SAME: attributes {arm_streaming}
+func.func @contains_scalable_vectors(%vec: vector<[4]xf32>) -> vector<[4]xf32> {
+ %0 = arith.addf %vec, %vec : vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @no_scalable_vectors
+// IF-SCALABLE-NOT: arm_streaming
+func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
+ %0 = arith.addf %vec, %vec : vector<4xf32>
+ return %0 : vector<4xf32>
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
index 10ffed2688178..aabd9d2ce788e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul-mixed-types.mlir
@@ -4,7 +4,7 @@
// RUN: -arm-sme-vector-legalization -canonicalize -cse \
// RUN: -convert-vector-to-arm-sme -arm-sme-outer-product-fusion \
// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za if-required-by-ops" \
// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
// RUN: -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index d3dabaf200fdc..a220791969d53 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -74,7 +74,7 @@ void buildTestLowerToArmSME(OpPassManager &pm,
// Enable streaming-mode and ZA.
pm.addPass(arm_sme::createEnableArmStreamingPass(
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
- /*onlyIfRequiredByOps=*/true));
+ /*ifRequiredByOps=*/true));
// Convert SCF to CF (required for ArmSME tile allocation).
pm.addPass(createConvertSCFToCFPass());
``````````
</details>
https://github.com/llvm/llvm-project/pull/94759
More information about the Mlir-commits
mailing list