[Mlir-commits] [mlir] [MLIR] Flatten fused locations when merging constants. (PR #75218)

Benjamin Chetioui llvmlistbot at llvm.org
Tue Dec 12 13:00:15 PST 2023


https://github.com/bchetioui updated https://github.com/llvm/llvm-project/pull/75218

>From c78733fe1fef5928f0237e4b361cce577db2f742 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui at google.com>
Date: Tue, 12 Dec 2023 17:08:36 +0000
Subject: [PATCH] [MLIR] Flatten fused locations when merging constants.

[PR 74670](https://github.com/llvm/llvm-project/pull/74670) added
support for merging locations at constant folding time. We have
discovered that, in some cases, the number of locations grows so large
that the compilation process runs out of memory (OOM). In those cases,
many of the locations appear several times in nested fused locations.

We add here a helper that always flattens fused locations in order to
eliminate duplicates in the case of nested fused locations. We only
allow flattening nested fused locations when the inner fused location
has no metadata, or has the same metadata as the outer fused location.
---
 mlir/lib/Transforms/Utils/FoldUtils.cpp       | 37 ++++++++++++++++++-
 .../Transforms/canonicalize-debuginfo.mlir    | 13 +++++--
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index dfc63ed6c4a542..056a681718e121 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -331,6 +331,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
   return newIt.first->second;
 }
 
+/// Flattens nested fused locations into a single fused location, dropping
+/// duplicate entries. Fused locations nested under non-fused locations are not
+/// flattened; calling this on a non-fused location returns it unchanged.
+///
+/// A child fused location is only flattened into its parent fused location if
+/// the child has no metadata, or if the parent and child have the same
+/// metadata. This avoids breaking cases where the metadata matters.
+/// Otherwise, the (already flattened) child is kept as a single entry.
+static Location FlattenFusedLocationRecursively(const Location loc) {
+  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
+    SetVector<Location> flattenedLocs;
+    Attribute metadata = fusedLoc.getMetadata();
+
+    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
+      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+
+      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                                flattenedFusedLoc.getMetadata() == metadata)) {
+        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+      } else {
+        flattenedLocs.insert(flattenedLoc);
+      }
+    }
+
+    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                         fusedLoc.getMetadata());
+  }
+
+  return loc;
+}
+
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,
                                            Location foldedLocation) {
   // Append into existing fused location if it has the same tag.
@@ -344,7 +377,7 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
       locations.insert(foldedLocation);
       Location newFusedLoc = FusedLoc::get(
           retainedOp->getContext(), locations.takeVector(), existingMetadata);
-      retainedOp->setLoc(newFusedLoc);
+      retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
       return;
     }
   }
@@ -357,5 +390,5 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
   Location newFusedLoc =
       FusedLoc::get(retainedOp->getContext(),
                     {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
-  retainedOp->setLoc(newFusedLoc);
+  retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
 }
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
index 034c9163a8059f..217cc29c0095e2 100644
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ b/mlir/test/Transforms/canonicalize-debuginfo.mlir
@@ -1,19 +1,26 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
 
 // CHECK-LABEL: func @merge_constants
-func.func @merge_constants() -> (index, index, index, index) {
+func.func @merge_constants() -> (index, index, index, index, index, index, index) {
   // CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
   %0 = arith.constant 42 : index loc("merge_constants":0:0)
   %1 = arith.constant 42 : index loc("merge_constants":1:0)
   %2 = arith.constant 42 : index loc("merge_constants":2:0)
   %3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
-  return %0, %1, %2, %3: index, index, index, index
+  %4 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
+  %5 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
+  %6 = arith.constant 43 : index loc(fused<"some_other_label">["merge_constants":3:0])
+  return %0, %1, %2, %3, %4, %5, %6 : index, index, index, index, index, index, index
 }
 
 // CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
 // CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
 // CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
+// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
+// CHECK-DAG: #[[FusedLoc_CSE_1:.*]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
+// CHECK-DAG: #[[FusedLoc_Some_Label:.*]] = loc(fused<"some_label">[#[[LocConst3]]])
+// CHECK-DAG: #[[FusedLoc_Some_Other_Label:.*]] = loc(fused<"some_other_label">[#[[LocConst3]]])
+// CHECK-DAG: #[[FusedLoc_CSE_2:.*]] = loc(fused<"CSE">[#[[FusedLoc_Some_Label]], #[[FusedLoc_Some_Other_Label]]])
 
 // -----
 



More information about the Mlir-commits mailing list