[Mlir-commits] [mlir] [MLIR][NFC] Add fast path to fused loc flattening. (PR #75312)

llvmlistbot at llvm.org
Wed Dec 13 02:22:20 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-core

Author: Benjamin Chetioui (bchetioui)

<details>
<summary>Changes</summary>

This is a follow-up to [PR 75218](https://github.com/llvm/llvm-project/pull/75218). It makes the `FlattenFusedLocationRecursively` helper return the original fused location, instead of reconstructing an identical one, when flattening leaves it unchanged.
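Below is a minimal, self-contained sketch of the fast path. It is not the upstream code. It uses a hypothetical `Loc` tree in place of MLIR's uniqued `Location`/`FusedLoc` attributes, so pointer identity stands in for attribute uniquing. `SetVector` is approximated by a vector with a linear duplicate check. The names `Loc`, `leaf`, `fusedLoc`, `insertUnique`, and `flatten` are illustrative only.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an MLIR Location: either a leaf (a name) or a
// fused node holding child locations plus optional metadata ("" = none).
struct Loc {
  std::string name;      // leaf name (unused for fused nodes)
  bool fused = false;
  std::string metadata;  // "" means no metadata
  std::vector<std::shared_ptr<const Loc>> children;
};
using LocPtr = std::shared_ptr<const Loc>;

LocPtr leaf(std::string name) {
  auto l = std::make_shared<Loc>();
  l->name = std::move(name);
  return l;
}

LocPtr fusedLoc(std::vector<LocPtr> children, std::string metadata = "") {
  auto l = std::make_shared<Loc>();
  l->fused = true;
  l->metadata = std::move(metadata);
  l->children = std::move(children);
  return l;
}

// Insert-if-absent, keeping first-seen order (approximates SetVector).
// Duplicates are detected by pointer identity (stands in for uniquing).
void insertUnique(std::vector<LocPtr> &v, const LocPtr &l) {
  if (std::find(v.begin(), v.end(), l) == v.end()) v.push_back(l);
}

LocPtr flatten(const LocPtr &loc) {
  if (!loc->fused) return loc;  // non-fused locations are returned as-is

  std::vector<LocPtr> flattened;
  bool changed = false;
  for (const LocPtr &child : loc->children) {
    LocPtr f = flatten(child);
    // Inline a nested fused loc if it has no metadata or the same metadata.
    if (f->fused && (f->metadata.empty() || f->metadata == loc->metadata)) {
      changed = true;
      for (const LocPtr &nested : f->children) insertUnique(flattened, nested);
    } else {
      if (f != child) changed = true;  // child was rewritten during recursion
      insertUnique(flattened, f);
    }
  }
  // Fast path: nothing inlined or rewritten and no duplicates dropped, so
  // hand back the original object instead of building an identical one.
  if (!changed && flattened.size() == loc->children.size()) return loc;
  return fusedLoc(std::move(flattened), loc->metadata);
}
```

With this sketch, flattening `fusedLoc({a, b})` returns the very same object. By contrast, `fusedLoc({a, fusedLoc({b})})` or `fusedLoc({a, a})` produces a new node. The size comparison is needed because deduplication alone can shrink the list even when no child changed.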

---
Full diff: https://github.com/llvm/llvm-project/pull/75312.diff


1 File Affected:

- (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+28-17) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e12..aa1e1ce01777d 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
 /// child fused locations are the same---this to avoid breaking cases where
 /// metadata matter.
 static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc) return loc;
+
+  SetVector<Location> flattenedLocs;
+  Attribute metadata = fusedLoc.getMetadata();
+  ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+  bool hasAnyNestedLocChanged = false;
+
+  for (const Location &unflattenedLoc : unflattenedLocs) {
+    Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+    auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+    if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                              flattenedFusedLoc.getMetadata() == metadata)) {
+      hasAnyNestedLocChanged = true;
+      ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+      flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+    } else {
+      if (flattenedLoc != unflattenedLoc) {
+        hasAnyNestedLocChanged = true;
       }
+
+      flattenedLocs.insert(flattenedLoc);
     }
+  }
 
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
+  if (!hasAnyNestedLocChanged &&
+      unflattenedLocs.size() == flattenedLocs.size()) {
+    return loc;
   }
 
-  return loc;
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                        fusedLoc.getMetadata());
 }
 
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,

``````````

</details>


https://github.com/llvm/llvm-project/pull/75312

