[Mlir-commits] [mlir] [mlir][ArmSME] Don't allow enabling streaming mode for gathers/scatters (PR #96209)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 20 09:02:25 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
Ideally, this would be based on target information (but we don't really have that), so this currently errs on the side of caution. If possible, gathers/scatters should be lowered to regular vector loads/stores before invoking enable-arm-streaming.
---
Full diff: https://github.com/llvm/llvm-project/pull/96209.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+2-2)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+32-17)
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir (+2-2)
- (modified) mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir (+18-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index c1f016d9ce1f1..1beae8cdfb336 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -120,10 +120,10 @@ def EnableArmStreaming
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
" ops that implement the ArmSMETileOpInterface.">,
- Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+ Option<"ifCompatibleAndScalable", "if-compatible-and-scalable",
"bool", /*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
- " operations that use scalable vector types.">
+ " compatible scalable vector operations.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index fb4bb41d87488..cc798d5f91507 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -55,22 +55,33 @@ namespace {
constexpr StringLiteral
kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
+template <typename... Ops>
+constexpr auto opList() {
+ return std::array{TypeID::get<Ops>()...};
+}
+
+bool isScalableVector(Type type) {
+ if (auto vectorType = dyn_cast<VectorType>(type))
+ return vectorType.isScalable();
+ return false;
+}
+
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ bool ifRequiredByOps, bool ifCompatibleAndScalable) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->ifRequiredByOps = ifRequiredByOps;
- this->ifContainsScalableVectors = ifContainsScalableVectors;
+ this->ifCompatibleAndScalable = ifCompatibleAndScalable;
}
void runOnOperation() override {
auto function = getOperation();
- if (ifRequiredByOps && ifContainsScalableVectors) {
+ if (ifRequiredByOps && ifCompatibleAndScalable) {
function->emitOpError(
"enable-arm-streaming: `if-required-by-ops` and "
- "`if-contains-scalable-vectors` are mutually exclusive");
+ "`if-compatible-and-scalable` are mutually exclusive");
return signalPassFailure();
}
@@ -87,22 +98,26 @@ struct EnableArmStreamingPass
return;
}
- if (ifContainsScalableVectors) {
- bool foundScalableVector = false;
- auto isScalableVector = [&](Type type) {
- if (auto vectorType = dyn_cast<VectorType>(type))
- return vectorType.isScalable();
- return false;
- };
+ if (ifCompatibleAndScalable) {
+ // FIXME: This should be based on target information. This currently errs
+ // on the side of caution. If possible, gathers/scatters should be lowered
+ // to regular vector loads/stores before invoking this pass.
+ auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
+ bool isCompatibleScalableFunction = false;
function.walk([&](Operation *op) {
- if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
- llvm::any_of(op->getResultTypes(), isScalableVector)) {
- foundScalableVector = true;
+ if (llvm::is_contained(disallowedOperations,
+ op->getName().getTypeID())) {
+ isCompatibleScalableFunction = false;
return WalkResult::interrupt();
}
+ if (!isCompatibleScalableFunction &&
+ (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+ llvm::any_of(op->getResultTypes(), isScalableVector))) {
+ isCompatibleScalableFunction = true;
+ }
return WalkResult::advance();
});
- if (!foundScalableVector)
+ if (!isCompatibleScalableFunction)
return;
}
@@ -126,7 +141,7 @@ struct EnableArmStreamingPass
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ bool ifRequiredByOps, bool ifCompatibleAndScalable) {
return std::make_unique<EnableArmStreamingPass>(
- streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
+ streamingMode, zaMode, ifRequiredByOps, ifCompatibleAndScalable);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
index da70b632d70c4..859de24597c92 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+// RUN: mlir-opt %s -enable-arm-streaming="if-compatible-and-scalable if-required-by-ops" -verify-diagnostics
-// expected-error @below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+// expected-error @below {{enable-arm-streaming: `if-required-by-ops` and `if-compatible-and-scalable` are mutually exclusive}}
func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 2011802c5c8b2..6d9375c50fa91 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -3,7 +3,7 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
-// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
+// RUN: mlir-opt %s -enable-arm-streaming=if-compatible-and-scalable -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -53,3 +53,20 @@ func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
%0 = arith.addf %vec, %vec : vector<4xf32>
return %0 : vector<4xf32>
}
+
+// IF-SCALABLE-LABEL: @contains_gather
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_gather(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %pass_thru: vector<[4]xf32>) -> vector<[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @contains_scatter
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_scatter(%base: memref<?xf32>, %v: vector<[4]xindex>,%mask: vector<[4]xi1>, %value: vector<[4]xf32>)
+{
+ %c0 = arith.constant 0 : index
+ vector.scatter %base[%c0][%v], %mask, %value : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/96209
More information about the Mlir-commits
mailing list