[Mlir-commits] [mlir] 1b64ed0 - [mlir][ArmSME] Disallow streaming mode for gathers/scatters (#96209)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 24 03:10:03 PDT 2024


Author: Benjamin Maxwell
Date: 2024-06-24T11:10:00+01:00
New Revision: 1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28

URL: https://github.com/llvm/llvm-project/commit/1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28
DIFF: https://github.com/llvm/llvm-project/commit/1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28.diff

LOG: [mlir][ArmSME] Disallow streaming mode for gathers/scatters (#96209)

Ideally, this would be based on target information (but we don't really
have that), so this currently errs on the side of caution. If possible,
gathers/scatters should be lowered to regular vector loads/stores before
invoking enable-arm-streaming.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
    mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
    mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
    mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index c1f016d9ce1f1..dfd64f995546a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -120,10 +120,10 @@ def EnableArmStreaming
            /*default=*/"false",
            "Only apply the selected streaming/ZA modes if the function contains"
            " ops that implement the ArmSMETileOpInterface.">,
-    Option<"ifContainsScalableVectors", "if-contains-scalable-vectors",
+    Option<"ifScalableAndSupported", "if-scalable-and-supported",
            "bool", /*default=*/"false",
            "Only apply the selected streaming/ZA modes if the function contains"
-           " operations that use scalable vector types.">
+           " supported scalable vector operations.">
   ];
   let dependentDialects = ["func::FuncDialect"];
 }

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index fb4bb41d87488..eafdc1de5ef34 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -55,22 +55,33 @@ namespace {
 constexpr StringLiteral
     kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
 
+template <typename... Ops>
+constexpr auto opList() {
+  return std::array{TypeID::get<Ops>()...};
+}
+
+bool isScalableVector(Type type) {
+  if (auto vectorType = dyn_cast<VectorType>(type))
+    return vectorType.isScalable();
+  return false;
+}
+
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
-                         bool ifRequiredByOps, bool ifContainsScalableVectors) {
+                         bool ifRequiredByOps, bool ifScalableAndSupported) {
     this->streamingMode = streamingMode;
     this->zaMode = zaMode;
     this->ifRequiredByOps = ifRequiredByOps;
-    this->ifContainsScalableVectors = ifContainsScalableVectors;
+    this->ifScalableAndSupported = ifScalableAndSupported;
   }
   void runOnOperation() override {
     auto function = getOperation();
 
-    if (ifRequiredByOps && ifContainsScalableVectors) {
+    if (ifRequiredByOps && ifScalableAndSupported) {
       function->emitOpError(
           "enable-arm-streaming: `if-required-by-ops` and "
-          "`if-contains-scalable-vectors` are mutually exclusive");
+          "`if-scalable-and-supported` are mutually exclusive");
       return signalPassFailure();
     }
 
@@ -87,22 +98,27 @@ struct EnableArmStreamingPass
         return;
     }
 
-    if (ifContainsScalableVectors) {
-      bool foundScalableVector = false;
-      auto isScalableVector = [&](Type type) {
-        if (auto vectorType = dyn_cast<VectorType>(type))
-          return vectorType.isScalable();
-        return false;
-      };
+    if (ifScalableAndSupported) {
+      // FIXME: This should be based on target information (i.e., the presence
+      // of FEAT_SME_FA64). This currently errs on the side of caution. If
+      // possible, gathers/scatters should be lowered to regular vector
+      // loads/stores before invoking this pass.
+      auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
+      bool isCompatibleScalableFunction = false;
       function.walk([&](Operation *op) {
-        if (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
-            llvm::any_of(op->getResultTypes(), isScalableVector)) {
-          foundScalableVector = true;
+        if (llvm::is_contained(disallowedOperations,
+                               op->getName().getTypeID())) {
+          isCompatibleScalableFunction = false;
           return WalkResult::interrupt();
         }
+        if (!isCompatibleScalableFunction &&
+            (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
+             llvm::any_of(op->getResultTypes(), isScalableVector))) {
+          isCompatibleScalableFunction = true;
+        }
         return WalkResult::advance();
       });
-      if (!foundScalableVector)
+      if (!isCompatibleScalableFunction)
         return;
     }
 
@@ -126,7 +142,7 @@ struct EnableArmStreamingPass
 
 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
     const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
-    bool ifRequiredByOps, bool ifContainsScalableVectors) {
+    bool ifRequiredByOps, bool ifScalableAndSupported) {
   return std::make_unique<EnableArmStreamingPass>(
-      streamingMode, zaMode, ifRequiredByOps, ifContainsScalableVectors);
+      streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
 }

diff  --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
index da70b632d70c4..4e92f814c0ec8 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming-invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="if-contains-scalable-vectors if-required-by-ops" -verify-diagnostics
+// RUN: mlir-opt %s -enable-arm-streaming="if-scalable-and-supported if-required-by-ops" -verify-diagnostics
 
-// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-contains-scalable-vectors` are mutually exclusive}}
+// expected-error at below {{enable-arm-streaming: `if-required-by-ops` and `if-scalable-and-supported` are mutually exclusive}}
 func.func @test() { return }

diff  --git a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
index 2011802c5c8b2..00b38b86f0d6e 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir
@@ -3,7 +3,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming=streaming-mode=streaming-compatible -verify-diagnostics | FileCheck %s -check-prefix=CHECK-COMPATIBLE
 // RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 // RUN: mlir-opt %s -enable-arm-streaming=if-required-by-ops -verify-diagnostics | FileCheck %s -check-prefix=IF-REQUIRED
-// RUN: mlir-opt %s -enable-arm-streaming=if-contains-scalable-vectors -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
+// RUN: mlir-opt %s -enable-arm-streaming=if-scalable-and-supported -verify-diagnostics | FileCheck %s -check-prefix=IF-SCALABLE
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
@@ -53,3 +53,20 @@ func.func @no_scalable_vectors(%vec: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.addf %vec, %vec : vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// IF-SCALABLE-LABEL: @contains_gather
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_gather(%base: memref<?xf32>, %v: vector<[4]xindex>, %mask: vector<[4]xi1>, %pass_thru: vector<[4]xf32>) -> vector<[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// IF-SCALABLE-LABEL: @contains_scatter
+// IF-SCALABLE-NOT: arm_streaming
+func.func @contains_scatter(%base: memref<?xf32>, %v: vector<[4]xindex>,%mask: vector<[4]xi1>, %value: vector<[4]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  vector.scatter %base[%c0][%v], %mask, %value : memref<?xf32>, vector<[4]xindex>, vector<[4]xi1>, vector<[4]xf32>
+  return
+}


        


More information about the Mlir-commits mailing list