[Mlir-commits] [mlir] [MLIR] Flatten fused locations when merging constants. (PR #75218)

Benjamin Chetioui llvmlistbot at llvm.org
Tue Dec 12 09:14:35 PST 2023


https://github.com/bchetioui created https://github.com/llvm/llvm-project/pull/75218

[PR 74670](https://github.com/llvm/llvm-project/pull/74670) added support for merging locations at constant folding time. We have discovered that in some cases the number of locations grows so large that the compilation process runs out of memory (OOM). In those cases, many of the locations appear several times across nested fused locations.

This patch adds a helper that always flattens nested fused locations into a single fused location, eliminating the duplicates.

>From 930030e1ff5d27c60d8fbb9840da4d8bf4eaca8e Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui at google.com>
Date: Tue, 12 Dec 2023 17:08:36 +0000
Subject: [PATCH] [MLIR] Flatten fused locations when merging constants.

[PR 74670](https://github.com/llvm/llvm-project/pull/74670) added
support for merging locations at constant folding time. We have
discovered that in some cases the number of locations grows so large
that the compilation process runs out of memory (OOM). In those cases,
many of the locations appear several times across nested fused
locations.

This patch adds a helper that always flattens nested fused locations
into a single fused location, eliminating the duplicates.
---
 mlir/lib/Transforms/Utils/FoldUtils.cpp       | 32 +++++++++++++++++--
 .../Transforms/canonicalize-debuginfo.mlir    |  8 +++--
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index dfc63ed6c4a542..f414cca38f36ff 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -331,6 +331,34 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
   return newIt.first->second;
 }
 
+/// Appends the leaf locations of `loc` to `leaves`, descending through nested
+/// fused locations. Collecting leaves directly avoids uniquing a throwaway
+/// FusedLoc in the context for every nesting level of the input.
+static void collectFusedLocationLeaves(Location loc,
+                                       SetVector<Location> &leaves) {
+  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
+    for (Location nestedLoc : fusedLoc.getLocations())
+      collectFusedLocationLeaves(nestedLoc, leaves);
+    return;
+  }
+  leaves.insert(loc);
+}
+
+/// Helper that flattens nested fused locations to a single fused location.
+/// Fused locations nested under non-fused locations are not flattened, and
+/// calling this on non-fused locations is a no-op. Only the metadata of the
+/// outermost fused location is kept; nested fused metadata is dropped.
+static Location FlattenFusedLocationRecursively(const Location loc) {
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc)
+    return loc;
+
+  SetVector<Location> flattenedLocs;
+  collectFusedLocationLeaves(loc, flattenedLocs);
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                       fusedLoc.getMetadata());
+}
+
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,
                                            Location foldedLocation) {
   // Append into existing fused location if it has the same tag.
@@ -344,7 +372,7 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
       locations.insert(foldedLocation);
       Location newFusedLoc = FusedLoc::get(
           retainedOp->getContext(), locations.takeVector(), existingMetadata);
-      retainedOp->setLoc(newFusedLoc);
+      retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
       return;
     }
   }
@@ -357,5 +385,5 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
   Location newFusedLoc =
       FusedLoc::get(retainedOp->getContext(),
                     {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
-  retainedOp->setLoc(newFusedLoc);
+  retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
 }
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
index 034c9163a8059f..3cf98900a7c54a 100644
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ b/mlir/test/Transforms/canonicalize-debuginfo.mlir
@@ -1,19 +1,21 @@
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
 
 // CHECK-LABEL: func @merge_constants
-func.func @merge_constants() -> (index, index, index, index) {
+func.func @merge_constants() -> (index, index, index, index, index) {
   // CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
   %0 = arith.constant 42 : index loc("merge_constants":0:0)
   %1 = arith.constant 42 : index loc("merge_constants":1:0)
   %2 = arith.constant 42 : index loc("merge_constants":2:0)
   %3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
-  return %0, %1, %2, %3: index, index, index, index
+  %4 = arith.constant 42 : index loc(fused<"some_label">["merge_constants":3:0]) // nested fused loc: expected to be flattened into the "CSE" loc
+  return %0, %1, %2, %3, %4 : index, index, index, index, index
 }
 
 // CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
 // CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
 // CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
+// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]], #[[LocConst3]]])
 
 // -----
 



More information about the Mlir-commits mailing list