[Mlir-commits] [mlir] 6fe3cd5 - [MLIR][NFC] Add fast path to fused loc flattening. (#75312)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 13 03:40:45 PST 2023
Author: Benjamin Chetioui
Date: 2023-12-13T12:40:41+01:00
New Revision: 6fe3cd54670cae52dae92a38a6d28f450fe8f321
URL: https://github.com/llvm/llvm-project/commit/6fe3cd54670cae52dae92a38a6d28f450fe8f321
DIFF: https://github.com/llvm/llvm-project/commit/6fe3cd54670cae52dae92a38a6d28f450fe8f321.diff
LOG: [MLIR][NFC] Add fast path to fused loc flattening. (#75312)
This is a follow-up to [PR 75218](https://github.com/llvm/llvm-project/pull/75218).
It avoids reconstructing a fused loc in the `FlattenFusedLocationRecursively`
helper when flattening leaves the location unchanged.
Added:
Modified:
mlir/lib/Transforms/Utils/FoldUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..136c4d2216b865 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
/// child fused locations are the same---this to avoid breaking cases where
/// metadata matter.
static Location FlattenFusedLocationRecursively(const Location loc) {
- if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
- SetVector<Location> flattenedLocs;
- Attribute metadata = fusedLoc.getMetadata();
-
- for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
- Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
- auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
- if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
- flattenedFusedLoc.getMetadata() == metadata)) {
- ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
- flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
- } else {
- flattenedLocs.insert(flattenedLoc);
- }
+ auto fusedLoc = dyn_cast<FusedLoc>(loc);
+ if (!fusedLoc)
+ return loc;
+
+ SetVector<Location> flattenedLocs;
+ Attribute metadata = fusedLoc.getMetadata();
+ ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+ bool hasAnyNestedLocChanged = false;
+
+ for (const Location &unflattenedLoc : unflattenedLocs) {
+ Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+ auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+ if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+ flattenedFusedLoc.getMetadata() == metadata)) {
+ hasAnyNestedLocChanged = true;
+ ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+ flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+ } else {
+ if (flattenedLoc != unflattenedLoc)
+ hasAnyNestedLocChanged = true;
+
+ flattenedLocs.insert(flattenedLoc);
}
+ }
- return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
- fusedLoc.getMetadata());
+ if (!hasAnyNestedLocChanged &&
+ unflattenedLocs.size() == flattenedLocs.size()) {
+ return loc;
}
- return loc;
+ return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+ fusedLoc.getMetadata());
}
void OperationFolder::appendFoldedLocation(Operation *retainedOp,
More information about the Mlir-commits
mailing list