[Mlir-commits] [mlir] [MLIR] Flatten fused locations when merging constants. (PR #75218)

Benjamin Chetioui llvmlistbot at llvm.org
Tue Dec 12 12:02:43 PST 2023


https://github.com/bchetioui updated https://github.com/llvm/llvm-project/pull/75218

>From 10b59644b6e35dd310da8dd7ba1abafb89a0c5c9 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui at google.com>
Date: Tue, 12 Dec 2023 17:08:36 +0000
Subject: [PATCH] [MLIR] Flatten fused locations when merging constants.

[PR 74670](https://github.com/llvm/llvm-project/pull/74670) added
support for merging locations at constant folding time. We have
discovered that, in some cases, the number of locations grows so large
that the compilation process runs out of memory (OOM). In those cases,
many of the locations appear several times in nested fused locations.

We add here a helper that always flattens fused locations in order to
eliminate duplicates in the case of nested fused locations. We only
allow flattening nested fused locations when the inner fused location
has no metadata, or has the same metadata as the outer fused location.
---
 mlir/lib/Transforms/Utils/FoldUtils.cpp       | 40 ++++++++++++++++++-
 .../Transforms/canonicalize-debuginfo.mlir    |  8 ++--
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index dfc63ed6c4a542..ff0ef1f26c85bc 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -331,6 +331,42 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
   return newIt.first->second;
 }
 
+namespace {
+
+/// Helper that flattens nested fused locations to a single fused location.
+/// Fused locations nested under non-fused locations are not flattened, and
+/// calling this on non-fused locations is a no-op as a result.
+///
+/// Fused locations are only flattened into parent fused locations if the
+/// child fused location has no metadata, or if the metadata of the parent and
+/// child fused locations are the same. This avoids breaking cases where the
+/// metadata matters.
+Location FlattenFusedLocationRecursively(const Location loc) {
+  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
+    SetVector<Location> flattenedLocs;
+    Attribute metadata = fusedLoc.getMetadata();
+
+    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
+      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+
+      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                                flattenedFusedLoc.getMetadata() == metadata)) {
+        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+      } else {
+        flattenedLocs.insert(flattenedLoc);
+      }
+    }
+
+    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                         fusedLoc.getMetadata());
+  }
+
+  return loc;
+}
+} // anonymous namespace
+
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,
                                            Location foldedLocation) {
   // Append into existing fused location if it has the same tag.
@@ -344,7 +380,7 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
       locations.insert(foldedLocation);
       Location newFusedLoc = FusedLoc::get(
           retainedOp->getContext(), locations.takeVector(), existingMetadata);
-      retainedOp->setLoc(newFusedLoc);
+      retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
       return;
     }
   }
@@ -357,5 +393,5 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
   Location newFusedLoc =
       FusedLoc::get(retainedOp->getContext(),
                     {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
-  retainedOp->setLoc(newFusedLoc);
+  retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
 }
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
index 034c9163a8059f..3cf98900a7c54a 100644
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ b/mlir/test/Transforms/canonicalize-debuginfo.mlir
@@ -1,19 +1,21 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
 
 // CHECK-LABEL: func @merge_constants
-func.func @merge_constants() -> (index, index, index, index) {
+func.func @merge_constants() -> (index, index, index, index, index) {
   // CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
   %0 = arith.constant 42 : index loc("merge_constants":0:0)
   %1 = arith.constant 42 : index loc("merge_constants":1:0)
   %2 = arith.constant 42 : index loc("merge_constants":2:0)
   %3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
-  return %0, %1, %2, %3: index, index, index, index
+  %4 = arith.constant 42 : index loc(fused<"some_label">["merge_constants":3:0])
+  return %0, %1, %2, %3, %4 : index, index, index, index, index
 }
 
 // CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
 // CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
 // CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
+// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]], #[[LocConst3]]])
 
 // -----
 



More information about the Mlir-commits mailing list