[Mlir-commits] [mlir] [mlir][ArmSME][NFC] Check early for unsupported mask ops (PR #135955)

Matthias Springer llvmlistbot at llvm.org
Wed Apr 16 05:19:30 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/135955

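This avoids rollbacks in the dialect conversion, which are expensive. The `arm_sme.tile_load` and `arm_sme.tile_store` lowering patterns now reject unsupported mask ops (anything other than `vector.create_mask`) before they start building IR, instead of failing later inside `createLoadStoreForOverTileSlices`.

Note: This is in preparation for the One-Shot Dialect Conversion refactoring.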
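The benefit of failing early can be illustrated with a toy model. The sketch below is hypothetical: `JournalingRewriter`, `lateCheckPattern`, `earlyCheckPattern`, `rollbacksForUnsupportedMask` and the op names are placeholders made up for illustration, not MLIR's API. It only assumes what the description implies: if a pattern fails after it has modified the IR, the conversion driver must undo those modifications.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, heavily simplified model (NOT MLIR's API): a rewriter that
// records every op created by the current pattern so that a match failure
// can undo them, similar in spirit to a dialect conversion rollback.
struct JournalingRewriter {
  std::vector<std::string> ir;       // stand-in for the IR
  std::vector<std::string> journal;  // ops created by the current pattern
  std::size_t rolledBack = 0;        // total ops undone because of failures

  void createOp(const std::string &name) {
    ir.push_back(name);
    journal.push_back(name);
  }
  // Analogue of a match failure: erase everything the pattern created.
  bool fail() {
    rolledBack += journal.size();
    ir.resize(ir.size() - journal.size());
    journal.clear();
    return false;
  }
  bool succeed() {
    journal.clear();
    return true;
  }
};

// Old shape: IR is created before the mask kind is inspected.
bool lateCheckPattern(JournalingRewriter &rw, bool maskIsCreateMask) {
  rw.createOp("placeholder.setup0");
  rw.createOp("placeholder.setup1");
  if (!maskIsCreateMask)
    return rw.fail();  // both ops above must be rolled back
  rw.createOp("placeholder.loop");
  return rw.succeed();
}

// New shape: unsupported masks are rejected before touching the IR.
bool earlyCheckPattern(JournalingRewriter &rw, bool maskIsCreateMask) {
  if (!maskIsCreateMask)
    return rw.fail();  // nothing to roll back
  rw.createOp("placeholder.setup0");
  rw.createOp("placeholder.setup1");
  rw.createOp("placeholder.loop");
  return rw.succeed();
}

// Runs a pattern on an unsupported mask and reports how many ops were undone.
std::size_t rollbacksForUnsupportedMask(bool (*pattern)(JournalingRewriter &,
                                                        bool)) {
  JournalingRewriter rw;
  bool ok = pattern(rw, /*maskIsCreateMask=*/false);
  (void)ok;
  assert(!ok && rw.ir.empty());  // either way the IR ends up unchanged
  return rw.rolledBack;
}
```

In this model `rollbacksForUnsupportedMask(lateCheckPattern)` undoes two ops while `rollbacksForUnsupportedMask(earlyCheckPattern)` undoes none. Both leave the IR unchanged; the difference is only the wasted create-then-erase work, which is what checking early removes.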

From 8c1997eae3121689e844b75e241e4c8e6e700b63 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 16 Apr 2025 14:18:09 +0200
Subject: [PATCH] [mlir][ArmSME] Check early for unsupported mask ops

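This avoids rollbacks in the dialect conversion, which are expensive. The `arm_sme.tile_load` and `arm_sme.tile_store` lowering patterns now reject unsupported mask ops (anything other than `vector.create_mask`) before they start building IR, instead of failing later inside `createLoadStoreForOverTileSlices`.

Note: This is in preparation for the One-Shot Dialect Conversion refactoring.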
---
 mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 6ed29903ea407..630414030d98b 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -77,11 +77,6 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
   Value upperBound;
   if (mask) {
     auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
-    if (!createMaskOp)
-      return rewriter.notifyMatchFailure(
-          loc, "unsupported mask op, only 'vector.create_mask' is "
-               "currently supported");
-
     auto maskDim0 = createMaskOp.getOperands()[0];
     auto maskDim1 = createMaskOp.getOperands()[1];
 
@@ -184,6 +179,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
     Value initTile;
     if (mask) {
+      if (!mask.getDefiningOp<vector::CreateMaskOp>())
+        return rewriter.notifyMatchFailure(
+            loc, "unsupported mask op, only 'vector.create_mask' is "
+                 "currently supported");
       auto padOp = tileLoadOp.getPadding();
       assert(padOp && "expected padding when masking!");
 
@@ -373,6 +372,14 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
   LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
                                 PatternRewriter &rewriter) const override {
+    if (Value mask = tileStoreOp.getMask()) {
+      if (!mask.getDefiningOp<vector::CreateMaskOp>())
+        return rewriter.notifyMatchFailure(
+            tileStoreOp.getLoc(),
+            "unsupported mask op, only 'vector.create_mask' is "
+            "currently supported");
+    }
+
     // Create a loop that stores each active ZA tile slice from memory.
     return createLoadStoreForOverTileSlices(
         rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
