[Mlir-commits] [mlir] [mlir][Arm] Fix invalid rewrite pattern API violations (PR #78246)

Matthias Springer llvmlistbot at llvm.org
Tue Jan 16 03:06:22 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/78246

>From 43cfc9176fb37e59a30195064ba54ad095ce4284 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 16 Jan 2024 09:44:11 +0000
Subject: [PATCH 1/3] [mlir][Arm] Fix invalid rewrite pattern API violations

This commit fixes rewrite pattern API violations:
* A rewrite pattern must return "failure" if it did not modify the IR.
* In-place op modifications must be communicated to the rewriter
  (`updateRootInPlace`).

This commit fixes failures in `test/Dialect/ArmSVE/legalize-vector-storage.mlir`,
`test/Dialect/ArmSME/vector-ops-to-llvm.mlir`,
`test/Dialect/ArmSME/tile-allocation-invalid.mlir`,
`test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir`,
`test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir`, and
`test/Conversion/ArmSMEToLLVM/unsupported.mlir` that occur when the tests are
run with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` enabled.
---
 .../ArmSME/Transforms/TileAllocation.cpp      | 21 +++++++++++++------
 .../Transforms/LegalizeVectorStorage.cpp      |  3 ++-
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 49ea6bb5f8614e7..b4630c834ff2428 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -232,14 +232,11 @@ struct AssignTileIDsPattern
         static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
     auto tileId = allocateTileId(*tileType, tilesInUse);
     bool tileIsInMemory = failed(tileId);
-    if (!tileIsInMemory)
-      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
-    else {
+    if (tileIsInMemory) {
       // If we could not find a real tile ID, use an in-memory tile ID (ID >=
       // 16). A later pass will insert the necessary spills and reloads.
       tileId =
           getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
-      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
       tileOp->emitWarning(
           "failed to allocate SME virtual tile to operation, all tile "
           "operations will go through memory, expect degraded performance");
@@ -263,14 +260,26 @@ struct AssignTileIDsPattern
     SetVector<Operation *> dependantOps;
     findDependantOps(tileOp->getResult(0), dependantOps);
     auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
-    rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
     for (auto *op : dependantOps) {
       if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
         auto currentTileId = dependantTileOp.getTileId();
         if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
           return dependantTileOp.emitOpError(
               "already assigned different SME virtual tile!");
-        dependantTileOp.setTileId(tileIDAttr);
+      }
+    }
+
+    // Rewrite IR.
+    if (!tileIsInMemory)
+      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+    else {
+      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
+    }
+    rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
+    for (auto *op : dependantOps) {
+      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+        rewriter.updateRootInPlace(
+            dependantTileOp, [&]() { dependantTileOp.setTileId(tileIDAttr); });
       }
     }
 
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index bee1f3659753b78..ed695db0372f28f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -106,7 +106,8 @@ struct RelaxScalableVectorAllocaAlignment
 
     // Set alignment based on the defaults for SVE vectors and predicates.
     unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
-    allocaOp.setAlignment(aligment);
+    rewriter.updateRootInPlace(allocaOp,
+                               [&]() { allocaOp.setAlignment(aligment); });
 
     return success();
   }

>From 18c5ef0f0abbc785310a0a48c6cc5b1f33f3f9d6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at matthiasspringer.de>
Date: Tue, 16 Jan 2024 12:06:02 +0100
Subject: [PATCH 2/3] Update
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index b4630c834ff2428..30df0f0db80281a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -272,9 +272,8 @@ struct AssignTileIDsPattern
     // Rewrite IR.
     if (!tileIsInMemory)
       setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
-    else {
+    else
       setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
-    }
     rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
     for (auto *op : dependantOps) {
       if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {

>From 711f2bdb925f58301c51891947ac11de1f4fa555 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at matthiasspringer.de>
Date: Tue, 16 Jan 2024 12:06:15 +0100
Subject: [PATCH 3/3] Update
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp

Co-authored-by: Benjamin Maxwell <macdue at dueutil.tech>
---
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 30df0f0db80281a..4d49efecbe05c3d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -278,7 +278,7 @@ struct AssignTileIDsPattern
     for (auto *op : dependantOps) {
       if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
         rewriter.updateRootInPlace(
-            dependantTileOp, [&]() { dependantTileOp.setTileId(tileIDAttr); });
+            dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
       }
     }
 



More information about the Mlir-commits mailing list