[Mlir-commits] [mlir] [mlir][ArmSME] Don't allow enabling streaming mode for gathers/scatters (PR #96209)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jun 20 09:01:41 PDT 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/96209

Ideally, this would be based on target information (but we don't really have that), so this currently errs on the side of caution. If possible, gathers/scatters should be lowered to regular vector loads/stores before invoking enable-arm-streaming.

>From 9f871c49538d6c2d3532f8ea557d505cbd743ed2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 15:54:55 +0000
Subject: [PATCH] [mlir][ArmSME] Don't allow enabling streaming mode for
 gathers/scatters

Ideally, this would be based on target information (but we don't really
have that), so this currently errs on the side of caution. If possible,
gathers/scatters should be lowered to regular vector loads/stores before
invoking enable-arm-streaming.
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  4 +-
 .../ArmSME/Transforms/EnableArmStreaming.cpp  | 49 ++++++++++++-------
 .../ArmSME/enable-arm-streaming-invalid.mlir  |  4 +-
 .../Dialect/ArmSME/enable-arm-streaming.mlir  | 19 ++++++-
 4 files changed, 54 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index c1f016d9ce1f1..1beae8cdfb336 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -120,10 +120,10 @@ def EnableArmStreaming
            /*default=*/"false",
            "Only apply the selected streaming/ZA modes if the function contains"
            " ops that implement the ArmSMETileOpInterface.">,
-    Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+    Option<"ifCompatibleAndScalable", "if-compatible-and-scalable",
            "bool", /*default=*/"false",
            "Only apply the selected streaming/ZA modes if the function contains"
-           " operations that use scalable vector types.">
+           " compatible scalable vector operations.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index fb4bb41d87488..cc798d5f91507 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -55,22 +55,33 @@ namespace {
 constexpr StringLiteral
     kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
 
+/// Returns a std::array holding the TypeIDs of `Ops` (any number, even zero).
+template <typename... Ops> constexpr auto opList() {
+  return std::array<TypeID, sizeof...(Ops)>{TypeID::get<Ops>()...};
+}
+
+/// Returns true if `type` is a VectorType for which `isScalable()` holds.
+bool isScalableVector(Type type) {
+  auto vectorType = dyn_cast<VectorType>(type);
+  return vectorType && vectorType.isScalable();
+}
+
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
-                         bool ifRequiredByOps, bool ifContainsScalableVectors) {
+                         bool ifRequiredByOps, bool ifCompatibleAndScalable) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
     this->ifRequiredByOps = ifRequiredByOps;
-    this->ifContainsScalableVectors = ifContainsScalableVectors;
+    this->ifCompatibleAndScalable = ifCompatibleAndScalable;
   }
   void runOnOperation() override {
     auto function = getOperation();
 
-    if (ifRequiredByOps && ifContainsScalableVectors) {
+    if (ifRequiredByOps && ifCompatibleAndScalable) {
       function->emitOpError(
           "enable-arm-streaming: `if-required-by-ops` and "
-          "`if-contains-scalable-vectors` are mutually exclusive");
+          "`if-compatible-and-scalable` are mutually exclusive");
       return signalPassFailure();
     }
 
@@ -87,22 +98,26 @@ struct EnableArmStreamingPass
         return;
     }
 
-    if (ifContainsScalableVectors) {
-      bool foundScalableVector = false;
-      auto isScalableVector = [&](Type type) {
-        if (auto vectorType = dyn_cast<VectorType>(type))
-          return vectorType.isScalable();
-        return false;
-      };
+    if (ifCompatibleAndScalable) {
+      // FIXME: This should be based on target information. This currently errs
+      // on the side of caution. If possible, gathers/scatters should be
+      // lowered to regular vector loads/stores before invoking this pass.
+      auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
+      bool isCompatibleScalableFunction = false;
       function.walk([&](Operation *op) {
-        if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
-            llvm::any_of(op->getResultTypes(), isScalableVector)) {
-          foundScalableVector = true;
+        if (llvm::is_contained(disallowedOperations,
+                               op->getName().getTypeID())) {
+          isCompatibleScalableFunction = false;
           return WalkResult::interrupt();
         }
+        if (!isCompatibleScalableFunction &&
+            (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+             llvm::any_of(op->getResultTypes(), isScalableVector))) {
+          isCompatibleScalableFunction = true;
+        }
         return WalkResult::advance();
       });
-      if (!foundScalableVector)
+      if (!isCompatibleScalableFunction)
         return;
     }
 
@@ -126,7 +141,7 @@ struct EnableArmStreamingPass
 
 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
     const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
-    bool ifRequiredByOps, bool ifContainsScalableVectors) {
+    bool ifRequiredByOps, bool ifCompatibleAndScalable) {
   return std::make_unique<EnableArmStreamingPass>(
-      streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
+      streamingMode, zaMode, ifRequiredByOps, ifCompatibleAndScalable);
 }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
index da70b632d70c4..859de24597c92 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+// RUN: mlir-opt %s -enable-arm-streaming="if-compatible-and-scalable if-required-by-ops" -verify-diagnostics
 
-// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+// expected-error@below {{enable-arm-streaming: `if-required-by-ops` and `if-compatible-and-scalable` are mutually exclusive}}
 func.func @test() { return }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 2011802c5c8b2..6d9375c50fa91 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -3,7 +3,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
-// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
+// RUN: mlir-opt %s -enable-arm-streaming=if-compatible-and-scalable -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -53,3 +53,20 @@ func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.addf %vec, %vec : vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// IF-SCALABLE-LABEL: @contains_gather
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_gather(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %pass_thru: vector<[4]xf32>) -> vector<[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @contains_scatter
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_scatter(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %value: vector<[4]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  vector.scatter %base[%c0][%v], %mask, %value : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32>
+  return
+}



More information about the Mlir-commits mailing list