[Mlir-commits] [mlir] [MLIR][NFC] Add fast path to fused loc flattening. (PR #75312)
llvmlistbot at llvm.org
Wed Dec 13 02:22:20 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Chetioui (bchetioui)
<details>
<summary>Changes</summary>
This is a follow-up to [PR 75218](https://github.com/llvm/llvm-project/pull/75218). It adds a fast path to the `FlattenFusedLocationRecursively` helper: when flattening leaves a fused location unchanged, the helper returns the original location instead of reconstructing a new `FusedLoc`.
---
Full diff: https://github.com/llvm/llvm-project/pull/75312.diff
1 File Affected:
- (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+28-17)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..aa1e1ce01777db 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
/// child fused locations are the same---this to avoid breaking cases where
/// metadata matter.
static Location FlattenFusedLocationRecursively(const Location loc) {
- if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
- SetVector<Location> flattenedLocs;
- Attribute metadata = fusedLoc.getMetadata();
-
- for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
- Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
- auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
- if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
- flattenedFusedLoc.getMetadata() == metadata)) {
- ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
- flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
- } else {
- flattenedLocs.insert(flattenedLoc);
+ auto fusedLoc = dyn_cast<FusedLoc>(loc);
+ if (!fusedLoc) return loc;
+
+ SetVector<Location> flattenedLocs;
+ Attribute metadata = fusedLoc.getMetadata();
+ ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+ bool hasAnyNestedLocChanged = false;
+
+ for (const Location &unflattenedLoc : unflattenedLocs) {
+ Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+ auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+ if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+ flattenedFusedLoc.getMetadata() == metadata)) {
+ hasAnyNestedLocChanged = true;
+ ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+ flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+ } else {
+ if (flattenedLoc != unflattenedLoc) {
+ hasAnyNestedLocChanged = true;
}
+
+ flattenedLocs.insert(flattenedLoc);
}
+ }
- return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
- fusedLoc.getMetadata());
+ if (!hasAnyNestedLocChanged &&
+ unflattenedLocs.size() == flattenedLocs.size()) {
+ return loc;
}
- return loc;
+ return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+ fusedLoc.getMetadata());
}
void OperationFolder::appendFoldedLocation(Operation *retainedOp,
``````````
</details>
https://github.com/llvm/llvm-project/pull/75312
More information about the Mlir-commits mailing list