[Mlir-commits] [mlir] [mlir][ArmSME] Disallow streaming mode for gathers/scatters (PR #96209)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Jun 21 09:24:30 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/96209
>From 9f871c49538d6c2d3532f8ea557d505cbd743ed2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 15:54:55 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Don't allow enabling streaming mode for
gathers/scatters
Ideally, this would be based on target information (but we don't really
have that), so this currently errs on the side of caution. If possible,
gathers/scatters should be lowered to regular vector loads/stores before
invoking enable-arm-streaming.
---
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 4 +-
.../ArmSME/Transforms/EnableArmStreaming.cpp | 49 ++++++++++++-------
.../ArmSME/enable-arm-streaming-invalid.mlir | 4 +-
.../Dialect/ArmSME/enable-arm-streaming.mlir | 19 ++++++-
4 files changed, 54 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index c1f016d9ce1f1..1beae8cdfb336 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -120,10 +120,10 @@ def EnableArmStreaming
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
" ops that implement the ArmSMETileOpInterface.">,
- Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+ Option<"ifCompatibleAndScalable", "if-compatible-and-scalable",
"bool", /*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
- " operations that use scalable vector types.">
+ " compatible scalable vector operations.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index fb4bb41d87488..cc798d5f91507 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -55,22 +55,33 @@ namespace {
constexpr StringLiteral
kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
+template <typename... Ops>
+constexpr auto opList() { // Returns a std::array of the TypeIDs of `Ops`.
+ return std::array{TypeID::get<Ops>()...};
+}
+
+bool isScalableVector(Type type) { // True iff `type` is a scalable VectorType.
+ if (auto vectorType = dyn_cast<VectorType>(type))
+ return vectorType.isScalable();
+ return false;
+}
+
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ bool ifRequiredByOps, bool ifCompatibleAndScalable) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->ifRequiredByOps = ifRequiredByOps;
- this->ifContainsScalableVectors = ifContainsScalableVectors;
+ this->ifCompatibleAndScalable = ifCompatibleAndScalable;
}
void runOnOperation() override {
auto function = getOperation();
- if (ifRequiredByOps && ifContainsScalableVectors) {
+ if (ifRequiredByOps && ifCompatibleAndScalable) {
function->emitOpError(
"enable-arm-streaming: `if-required-by-ops` and "
- "`if-contains-scalable-vectors` are mutually exclusive");
+ "`if-compatible-and-scalable` are mutually exclusive");
return signalPassFailure();
}
@@ -87,22 +98,26 @@ struct EnableArmStreamingPass
return;
}
- if (ifContainsScalableVectors) {
- bool foundScalableVector = false;
- auto isScalableVector = [&](Type type) {
- if (auto vectorType = dyn_cast<VectorType>(type))
- return vectorType.isScalable();
- return false;
- };
+ if (ifCompatibleAndScalable) {
+ // FIXME: This should be based on target information. This currently errs
+ // on the side of caution. If possible gathers/scatters should be lowered
+ // regular vector loads/stores before invoking this pass.
+ auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
+ bool isCompatibleScalableFunction = false;
function.walk([&](Operation *op) {
- if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
- llvm::any_of(op->getResultTypes(), isScalableVector)) {
- foundScalableVector = true;
+ if (llvm::is_contained(disallowedOperations,
+ op->getName().getTypeID())) {
+ isCompatibleScalableFunction = false;
return WalkResult::interrupt();
}
+ if (!isCompatibleScalableFunction &&
+ (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+ llvm::any_of(op->getResultTypes(), isScalableVector))) {
+ isCompatibleScalableFunction = true;
+ }
return WalkResult::advance();
});
- if (!foundScalableVector)
+ if (!isCompatibleScalableFunction)
return;
}
@@ -126,7 +141,7 @@ struct EnableArmStreamingPass
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ bool ifRequiredByOps, bool ifCompatibleAndScalable) {
return std::make_unique<EnableArmStreamingPass>(
- streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
+ streamingMode, zaMode, ifRequiredByOps, ifCompatibleAndScalable);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
index da70b632d70c4..859de24597c92 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+// RUN: mlir-opt %s -enable-arm-streaming="if-compatible-and-scalable if-required-by-ops" -verify-diagnostics
-// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-compatible-and-scalable` are mutually exclusive}}
func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 2011802c5c8b2..6d9375c50fa91 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -3,7 +3,7 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
-// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
+// RUN: mlir-opt %s -enable-arm-streaming=if-compatible-and-scalable -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -53,3 +53,20 @@ func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
%0 = arith.addf %vec, %vec : vector<4xf32>
return %0 : vector<4xf32>
}
+
+// IF-SCALABLE-LABEL: @contains_gather
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_gather(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %pass_thru: vector<[4]xf32>) -> vector<[4]xf32> {
+ %c0 = arith.constant 0 : index // vector.gather is disallowed by the pass, so no streaming attribute is expected.
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @contains_scatter
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_scatter(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %value: vector<[4]xf32>)
+{ // vector.scatter is disallowed by the pass, so no streaming attribute is expected.
+ %c0 = arith.constant 0 : index
+ vector.scatter %base[%c0][%v], %mask, %value : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32>
+ return
+}
>From bddaf3b6564a8495b8568fb3940256051f62eed2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 21 Jun 2024 16:22:47 +0000
Subject: [PATCH 2/2] Fixups
---
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 4 ++--
.../ArmSME/Transforms/EnableArmStreaming.cpp | 14 +++++++-------
.../ArmSME/enable-arm-streaming-invalid.mlir | 4 ++--
mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir | 2 +-
4 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 1beae8cdfb336..dfd64f995546a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -120,10 +120,10 @@ def EnableArmStreaming
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
" ops that implement the ArmSMETileOpInterface.">,
- Option<"ifCompatibleAndScalable", "if-compatible-and-scalable",
+ Option<"ifScalableAndSupported", "if-scalable-and-supported",
"bool", /*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
- " compatible scalable vector operations.">
+ " supported scalable vector operations.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index cc798d5f91507..87e1c51656865 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -69,19 +69,19 @@ bool isScalableVector(Type type) {
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifCompatibleAndScalable) {
+ bool ifRequiredByOps, bool ifScalableAndSupported) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->ifRequiredByOps = ifRequiredByOps;
- this->ifCompatibleAndScalable = ifCompatibleAndScalable;
+ this->ifScalableAndSupported = ifScalableAndSupported;
}
void runOnOperation() override {
auto function = getOperation();
- if (ifRequiredByOps && ifCompatibleAndScalable) {
+ if (ifRequiredByOps && ifScalableAndSupported) {
function->emitOpError(
"enable-arm-streaming: `if-required-by-ops` and "
- "`if-compatible-and-scalable` are mutually exclusive");
+ "`if-scalable-and-supported` are mutually exclusive");
return signalPassFailure();
}
@@ -98,7 +98,7 @@ struct EnableArmStreamingPass
return;
}
- if (ifCompatibleAndScalable) {
+ if (ifScalableAndSupported) {
// FIXME: This should be based on target information. This currently errs
// on the side of caution. If possible gathers/scatters should be lowered
// regular vector loads/stores before invoking this pass.
@@ -141,7 +141,7 @@ struct EnableArmStreamingPass
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifCompatibleAndScalable) {
+ bool ifRequiredByOps, bool ifScalableAndSupported) {
return std::make_unique<EnableArmStreamingPass>(
- streamingMode, zaMode, ifRequiredByOps, ifCompatibleAndScalable);
+ streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
index 859de24597c92..4e92f814c0ec8 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="if-compatible-and-scalable if-required-by-ops" -verify-diagnostics
+// RUN: mlir-opt %s -enable-arm-streaming="if-scalable-and-supported if-required-by-ops" -verify-diagnostics
-// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-compatible-and-scalable` are mutually exclusive}}
+// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-scalable-and-supported` are mutually exclusive}}
func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 6d9375c50fa91..00b38b86f0d6e 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -3,7 +3,7 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
-// RUN: mlir-opt %s -enable-arm-streaming=if-compatible-and-scalable -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
+// RUN: mlir-opt %s -enable-arm-streaming=if-scalable-and-supported -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
More information about the Mlir-commits
mailing list