[Mlir-commits] [mlir] [MLIR][NFC] Add fast path to fused loc flattening. (PR #75312)
Benjamin Chetioui
llvmlistbot at llvm.org
Wed Dec 13 02:21:54 PST 2023
https://github.com/bchetioui created https://github.com/llvm/llvm-project/pull/75312
This is a follow-up on [PR 75218](https://github.com/llvm/llvm-project/pull/75218) that avoids reconstructing a fused loc in the `FlattenFusedLocationRecursively` helper when there has been no change.
From a980bd57a9989268293c541c825fd12ffd949173 Mon Sep 17 00:00:00 2001
From: Benjamin Chetioui <bchetioui at google.com>
Date: Wed, 13 Dec 2023 10:17:25 +0000
Subject: [PATCH] [MLIR][NFC] Add fast path to fused loc flattening.
This is a follow-up on
[PR 75218](https://github.com/llvm/llvm-project/pull/75218) that avoids
reconstructing a fused loc in the `FlattenFusedLocationRecursively`
helper when there has been no change.
---
mlir/lib/Transforms/Utils/FoldUtils.cpp | 49 ++++++++++++++++---------
1 file changed, 31 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..aa1e1ce01777db 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,41 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
 /// child fused locations are the same---this to avoid breaking cases where
 /// metadata matter.
 static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
-      }
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc)
+    return loc;
+
+  SetVector<Location> flattenedLocs;
+  Attribute metadata = fusedLoc.getMetadata();
+  ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+  bool hasAnyNestedLocChanged = false;
+
+  for (const Location &unflattenedLoc : unflattenedLocs) {
+    Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+    auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+    if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                              flattenedFusedLoc.getMetadata() == metadata)) {
+      hasAnyNestedLocChanged = true;
+      ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+      flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+    } else {
+      if (flattenedLoc != unflattenedLoc)
+        hasAnyNestedLocChanged = true;
+
+      flattenedLocs.insert(flattenedLoc);
     }
+  }
 
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
+  // Fast path: no child location was rewritten and no duplicate was dropped,
+  // so `loc` is already flat; return it instead of rebuilding the FusedLoc.
+  if (!hasAnyNestedLocChanged &&
+      unflattenedLocs.size() == flattenedLocs.size()) {
+    return loc;
   }
 
-  return loc;
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                       fusedLoc.getMetadata());
 }
 
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,
More information about the Mlir-commits
mailing list