[Mlir-commits] [mlir] [MLIR][NFC] Add fast path to fused loc flattening. (PR #75312)

Benjamin Chetioui llvmlistbot at llvm.org
Wed Dec 13 02:21:54 PST 2023


https://github.com/bchetioui created https://github.com/llvm/llvm-project/pull/75312

This is a follow-up to [PR 75218](https://github.com/llvm/llvm-project/pull/75218) that avoids reconstructing a fused location in the `FlattenFusedLocationRecursively` helper when flattening would leave it unchanged.

>From a980bd57a9989268293c541c825fd12ffd949173 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui at google.com>
Date: Wed, 13 Dec 2023 10:17:25 +0000
Subject: [PATCH] [MLIR][NFC] Add fast path to fused loc flattening.

This is a follow-up to
[PR 75218](https://github.com/llvm/llvm-project/pull/75218) that avoids
reconstructing a fused location in the
`FlattenFusedLocationRecursively` helper when flattening would leave it
unchanged.
---
 mlir/lib/Transforms/Utils/FoldUtils.cpp | 45 +++++++++++++++----------
 1 file changed, 28 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..aa1e1ce01777db 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
 /// child fused locations are the same---this to avoid breaking cases where
 /// metadata matter.
 static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc) return loc;
+
+  SetVector<Location> flattenedLocs;
+  Attribute metadata = fusedLoc.getMetadata();
+  ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+  bool hasAnyNestedLocChanged = false;
+
+  for (const Location &unflattenedLoc : unflattenedLocs) {
+    Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+    auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+    if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                              flattenedFusedLoc.getMetadata() == metadata)) {
+      hasAnyNestedLocChanged = true;
+      ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+      flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+    } else {
+      if (flattenedLoc != unflattenedLoc) {
+        hasAnyNestedLocChanged = true;
       }
+
+      flattenedLocs.insert(flattenedLoc);
     }
+  }
 
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
+  if (!hasAnyNestedLocChanged &&
+      unflattenedLocs.size() == flattenedLocs.size()) {
+    return loc;
   }
 
-  return loc;
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                        fusedLoc.getMetadata());
 }
 
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,



More information about the Mlir-commits mailing list