[Mlir-commits] [mlir] [MLIR] Fuse parent region location when hoisting constants (PR #75258)
Billy Zhu
llvmlistbot at llvm.org
Tue Dec 12 16:23:30 PST 2023
https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/75258
>From 8abe24bc6526c27c977339bc63f69d5f29c435b1 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 12 Dec 2023 14:57:26 -0800
Subject: [PATCH] fold parent location when hoisting
---
mlir/include/mlir/Transforms/FoldUtils.h | 10 +++--
mlir/lib/Transforms/Utils/FoldUtils.cpp | 40 +++++++++++--------
.../Transforms/canonicalize-debuginfo.mlir | 26 +++++++++++-
.../Transforms/constant-fold-debuginfo.mlir | 32 +++++++++++++--
4 files changed, 82 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 28fa18cf942de4..999b4774eae996 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -96,10 +96,12 @@ class OperationFolder {
Dialect *dialect, Attribute value,
Type type, Location loc);
- // Fuse `foldedLocation` into the Location of `retainedOp`. This will result
- // in `retainedOp` having a FusedLoc with `fusedLocationTag` to help trace the
- // source of the fusion. If `retainedOp` already had a FusedLoc with the same
- // tag, `foldedLocation` will simply be appended to it.
+ // Fuse `foldedLocation` into `originalLocation`. This will result in a
+ // FusedLoc with `fusedLocationTag` to help trace the source of the fusion.
+ // If `originalLocation` already had a FusedLoc with the same tag,
+ // `foldedLocation` will simply be appended to it.
+ Location getFusedLocation(Location originalLocation, Location foldedLocation);
+ // Update the location of `retainedOp` by applying `getFusedLocation`.
void appendFoldedLocation(Operation *retainedOp, Location foldedLocation);
/// Tag for annotating fused locations as a result of merging constants.
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..b2afe08430fc4d 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -152,8 +152,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
// anything. Otherwise, we move the constant to the insertion block.
Block *insertBlock = &insertRegion->front();
if (opBlock != insertBlock || (&insertBlock->front() != op &&
- !isFolderOwnedConstant(op->getPrevNode())))
+ !isFolderOwnedConstant(op->getPrevNode()))) {
op->moveBefore(&insertBlock->front());
+ appendFoldedLocation(op, insertBlock->getParent()->getLoc());
+ }
folderConstOp = op;
referencedDialects[op].push_back(op->getDialect());
@@ -237,6 +239,7 @@ OperationFolder::processFoldResults(Operation *op,
auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
auto &entry = insertRegion->front();
rewriter.setInsertionPoint(&entry, entry.begin());
+ Location loc = getFusedLocation(op->getLoc(), insertRegion->getLoc());
// Get the constant map for the insertion region of this operation.
auto &uniquedConstants = foldScopes[insertRegion];
@@ -259,8 +262,8 @@ OperationFolder::processFoldResults(Operation *op,
// Check to see if there is a canonicalized version of this constant.
auto res = op->getResult(i);
Attribute attrRepl = foldResults[i].get<Attribute>();
- if (auto *constOp = tryGetOrCreateConstant(
- uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
+ if (auto *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
+ attrRepl, res.getType(), loc)) {
// Ensure that this constant dominates the operation we are replacing it
// with. This may not automatically happen if the operation being folded
// was inserted before the constant within the insertion block.
@@ -364,31 +367,36 @@ static Location FlattenFusedLocationRecursively(const Location loc) {
return loc;
}
-void OperationFolder::appendFoldedLocation(Operation *retainedOp,
+Location OperationFolder::getFusedLocation(Location originalLocation,
Location foldedLocation) {
+ // If they're already equal, no need to fuse.
+ if (originalLocation == foldedLocation)
+ return originalLocation;
+
// Append into existing fused location if it has the same tag.
if (auto existingFusedLoc =
- dyn_cast<FusedLocWith<StringAttr>>(retainedOp->getLoc())) {
+ dyn_cast<FusedLocWith<StringAttr>>(originalLocation)) {
StringAttr existingMetadata = existingFusedLoc.getMetadata();
if (existingMetadata == fusedLocationTag) {
ArrayRef<Location> existingLocations = existingFusedLoc.getLocations();
SetVector<Location> locations(existingLocations.begin(),
existingLocations.end());
locations.insert(foldedLocation);
- Location newFusedLoc = FusedLoc::get(
- retainedOp->getContext(), locations.takeVector(), existingMetadata);
- retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
- return;
+ Location newFusedLoc =
+ FusedLoc::get(originalLocation->getContext(), locations.takeVector(),
+ existingMetadata);
+ return FlattenFusedLocationRecursively(newFusedLoc);
}
}
- // Create a new fusedloc with retainedOp's loc and foldedLocation.
+ // Create a new fusedloc with originalLocation and foldedLocation.
- // If they're already equal, no need to fuse.
- if (retainedOp->getLoc() == foldedLocation)
- return;
-
Location newFusedLoc =
- FusedLoc::get(retainedOp->getContext(),
- {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
- retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
+ FusedLoc::get(originalLocation->getContext(),
+ {originalLocation, foldedLocation}, fusedLocationTag);
+ return FlattenFusedLocationRecursively(newFusedLoc);
+}
+
+void OperationFolder::appendFoldedLocation(Operation *retainedOp,
+ Location foldedLocation) {
+ retainedOp->setLoc(getFusedLocation(retainedOp->getLoc(), foldedLocation));
}
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
index 217cc29c0095e2..8b726eabb4c2e8 100644
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ b/mlir/test/Transforms/canonicalize-debuginfo.mlir
@@ -33,9 +33,31 @@ func.func @hoist_constant(%arg0: memref<8xi32>) {
memref.store %0, %arg0[%arg1] : memref<8xi32>
memref.store %1, %arg0[%arg1] : memref<8xi32>
}
+ // CHECK: return
return
-}
+// CHECK-NEXT: } loc(#[[LocFunc:.*]])
+} loc("hoist_constant":2:0)
// CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]]])
+// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]], #[[LocConst1]]])
+
+// -----
+
+// CHECK-LABEL: func @hoist_constant_simple
+func.func @hoist_constant_simple(%arg0: memref<8xi32>) -> i32 {
+ // CHECK-NEXT: arith.constant 88 : i32 loc(#[[FusedLoc:.*]])
+ %0 = arith.constant 42 : i32 loc("hoist_constant_simple":0:0)
+ %1 = arith.constant 0 : index
+ memref.store %0, %arg0[%1] : memref<8xi32>
+
+ %2 = arith.constant 88 : i32 loc("hoist_constant_simple":1:0)
+
+ return %2 : i32
+// CHECK: } loc(#[[LocFunc:.*]])
+} loc("hoist_constant_simple":2:0)
+
+// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant_simple":1:0)
+// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant_simple":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]]])
diff --git a/mlir/test/Transforms/constant-fold-debuginfo.mlir b/mlir/test/Transforms/constant-fold-debuginfo.mlir
index 79a25f860a4841..7e5ffc3c4d5b13 100644
--- a/mlir/test/Transforms/constant-fold-debuginfo.mlir
+++ b/mlir/test/Transforms/constant-fold-debuginfo.mlir
@@ -11,11 +11,13 @@ func.func @fold_and_merge() -> (i32, i32) {
%3 = arith.constant 6 : i32 loc("fold_and_merge":1:0)
return %2, %3: i32, i32
-}
+// CHECK: } loc(#[[LocFunc:.*]])
+} loc("fold_and_merge":2:0)
// CHECK-DAG: #[[LocConst0:.*]] = loc("fold_and_merge":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("fold_and_merge":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
+// CHECK-DAG: #[[LocFunc]] = loc("fold_and_merge":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])
// -----
@@ -27,8 +29,30 @@ func.func @materialize_different_dialect() -> (f32, f32) {
%2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":1:0)
return %1, %2: f32, f32
-}
+// CHECK: } loc(#[[LocFunc:.*]])
+} loc("materialize_different_dialect":2:0)
// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_different_dialect":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("materialize_different_dialect":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
+// CHECK-DAG: #[[LocFunc]] = loc("materialize_different_dialect":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_front
+func.func @materialize_in_front(%arg0: memref<8xi32>) {
+ // CHECK-NEXT: arith.constant 6 : i32 loc(#[[FusedLoc:.*]])
+ affine.for %arg1 = 0 to 8 {
+ %1 = arith.constant 1 : i32
+ %2 = arith.constant 5 : i32
+ %3 = arith.addi %1, %2 : i32 loc("materialize_in_front":0:0)
+ memref.store %3, %arg0[%arg1] : memref<8xi32>
+ }
+ // CHECK: return
+ return
+// CHECK-NEXT: } loc(#[[LocFunc:.*]])
+} loc("materialize_in_front":1:0)
+
+// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_in_front":0:0)
+// CHECK-DAG: #[[LocFunc]] = loc("materialize_in_front":1:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]]])
More information about the Mlir-commits mailing list