[Mlir-commits] [mlir] 1b64ed0 - [mlir][ArmSME] Disallow streaming mode for gathers/scatters (#96209)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 24 03:10:03 PDT 2024
Author: Benjamin Maxwell
Date: 2024-06-24T11:10:00+01:00
New Revision: 1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28
URL: https://github.com/llvm/llvm-project/commit/1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28
DIFF: https://github.com/llvm/llvm-project/commit/1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28.diff
LOG: [mlir][ArmSME] Disallow streaming mode for gathers/scatters (#96209)
Ideally, this would be based on target information (but we don't really
have that), so this currently errs on the side of caution. If possible,
gathers/scatters should be lowered to regular vector loads/stores before
invoking enable-arm-streaming.
Added:
Modified:
mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index c1f016d9ce1f1..dfd64f995546a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -120,10 +120,10 @@ def EnableArmStreaming
/*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
" ops that implement the ArmSMETileOpInterface.">,
- Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+ Option<"ifScalableAndSupported", "if-scalable-and-supported",
"bool", /*default=*/"false",
"Only apply the selected streaming/ZA modes if the function contains"
- " operations that use scalable vector types.">
+ " supported scalable vector operations.">
];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index fb4bb41d87488..eafdc1de5ef34 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -55,22 +55,33 @@ namespace {
constexpr StringLiteral
kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
+template <typename... Ops>
+constexpr auto opList() {
+ return std::array{TypeID::get<Ops>()...};
+}
+
+bool isScalableVector(Type type) {
+ if (auto vectorType = dyn_cast<VectorType>(type))
+ return vectorType.isScalable();
+ return false;
+}
+
struct EnableArmStreamingPass
: public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ bool ifRequiredByOps, bool ifScalableAndSupported) {
this->streamingMode = streamingMode;
this->zaMode = zaMode;
this->ifRequiredByOps = ifRequiredByOps;
- this->ifContainsScalableVectors = ifContainsScalableVectors;
+ this->ifScalableAndSupported = ifScalableAndSupported;
}
void runOnOperation() override {
auto function = getOperation();
- if (ifRequiredByOps && ifContainsScalableVectors) {
+ if (ifRequiredByOps && ifScalableAndSupported) {
function->emitOpError(
"enable-arm-streaming: `if-required-by-ops` and "
- "`if-contains-scalable-vectors` are mutually exclusive");
+ "`if-scalable-and-supported` are mutually exclusive");
return signalPassFailure();
}
@@ -87,22 +98,27 @@ struct EnableArmStreamingPass
return;
}
- if (ifContainsScalableVectors) {
- bool foundScalableVector = false;
- auto isScalableVector = [&](Type type) {
- if (auto vectorType = dyn_cast<VectorType>(type))
- return vectorType.isScalable();
- return false;
- };
+ if (ifScalableAndSupported) {
+ // FIXME: This should be based on target information (i.e., the presence
+ // of FEAT_SME_FA64). This currently errs on the side of caution. If
+ // possible gathers/scatters should be lowered regular vector loads/stores
+ // before invoking this pass.
+ auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
+ bool isCompatibleScalableFunction = false;
function.walk([&](Operation *op) {
- if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
- llvm::any_of(op->getResultTypes(), isScalableVector)) {
- foundScalableVector = true;
+ if (llvm::is_contained(disallowedOperations,
+ op->getName().getTypeID())) {
+ isCompatibleScalableFunction = false;
return WalkResult::interrupt();
}
+ if (!isCompatibleScalableFunction &&
+ (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+ llvm::any_of(op->getResultTypes(), isScalableVector))) {
+ isCompatibleScalableFunction = true;
+ }
return WalkResult::advance();
});
- if (!foundScalableVector)
+ if (!isCompatibleScalableFunction)
return;
}
@@ -126,7 +142,7 @@ struct EnableArmStreamingPass
std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
- bool ifRequiredByOps, bool ifContainsScalableVectors) {
+ bool ifRequiredByOps, bool ifScalableAndSupported) {
return std::make_unique<EnableArmStreamingPass>(
- streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
+ streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
index da70b632d70c4..4e92f814c0ec8 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+// RUN: mlir-opt %s -enable-arm-streaming="if-scalable-and-supported if-required-by-ops" -verify-diagnostics
-// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-scalable-and-supported` are mutually exclusive}}
func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 2011802c5c8b2..00b38b86f0d6e 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -3,7 +3,7 @@
// RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
// RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
-// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
+// RUN: mlir-opt %s -enable-arm-streaming=if-scalable-and-supported -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
// CHECK-LABEL: @arm_streaming
// CHECK-SAME: attributes {arm_streaming}
@@ -53,3 +53,20 @@ func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
%0 = arith.addf %vec, %vec : vector<4xf32>
return %0 : vector<4xf32>
}
+
+// IF-SCALABLE-LABEL: @contains_gather
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_gather(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %pass_thru: vector<[4]xf32>) -> vector<[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @contains_scatter
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_scatter(%base: memref<?xf32>, %v: vector<[4]xindex>,%mask: vector<[4]xi1>, %value: vector<[4]xf32>)
+{
+ %c0 = arith.constant 0 : index
+ vector.scatter %base[%c0][%v], %mask, %value : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32>
+ return
+}
More information about the Mlir-commits
mailing list