[Mlir-commits] [mlir] 6fe3cd5 - [MLIR][NFC] Add fast path to fused loc flattening. (#75312)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 13 03:40:45 PST 2023


Author: Benjamin Chetioui
Date: 2023-12-13T12:40:41+01:00
New Revision: 6fe3cd54670cae52dae92a38a6d28f450fe8f321

URL: https://github.com/llvm/llvm-project/commit/6fe3cd54670cae52dae92a38a6d28f450fe8f321
DIFF: https://github.com/llvm/llvm-project/commit/6fe3cd54670cae52dae92a38a6d28f450fe8f321.diff

LOG: [MLIR][NFC] Add fast path to fused loc flattening. (#75312)

This is a follow-up to
[PR 75218](https://github.com/llvm/llvm-project/pull/75218) that avoids
reconstructing a fused loc in the `FlattenFusedLocationRecursively`
helper when flattening leaves the location unchanged.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/FoldUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..136c4d2216b865 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
 /// child fused locations are the same---this to avoid breaking cases where
 /// metadata matter.
 static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
-      }
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc)
+    return loc;
+
+  SetVector<Location> flattenedLocs;
+  Attribute metadata = fusedLoc.getMetadata();
+  ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+  bool hasAnyNestedLocChanged = false;
+
+  for (const Location &unflattenedLoc : unflattenedLocs) {
+    Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+    auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+    if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                              flattenedFusedLoc.getMetadata() == metadata)) {
+      hasAnyNestedLocChanged = true;
+      ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+      flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+    } else {
+      if (flattenedLoc != unflattenedLoc)
+        hasAnyNestedLocChanged = true;
+
+      flattenedLocs.insert(flattenedLoc);
     }
+  }
 
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
+  if (!hasAnyNestedLocChanged &&
+      unflattenedLocs.size() == flattenedLocs.size()) {
+    return loc;
   }
 
-  return loc;
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                       fusedLoc.getMetadata());
 }
 
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,


        


More information about the Mlir-commits mailing list