[Mlir-commits] [mlir] [MLIR] Fuse parent region location when hoisting constants (PR #75258)

Billy Zhu llvmlistbot at llvm.org
Tue Dec 12 16:11:21 PST 2023


https://github.com/zyx-billy updated https://github.com/llvm/llvm-project/pull/75258

>From 3bd3ea04f7b89b5122839cbc2bc2f9a36983c496 Mon Sep 17 00:00:00 2001
From: Billy Zhu <billyzhu at modular.com>
Date: Tue, 12 Dec 2023 14:57:26 -0800
Subject: [PATCH] fuse parent region location when hoisting constants

---
 mlir/include/mlir/Transforms/FoldUtils.h      | 10 +++--
 mlir/lib/Transforms/Utils/FoldUtils.cpp       | 45 ++++++++++++-------
 .../Transforms/canonicalize-debuginfo.mlir    | 27 ++++++++++-
 .../Transforms/constant-fold-debuginfo.mlir   | 32 +++++++++++--
 4 files changed, 87 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 28fa18cf942de..999b4774eae99 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -96,10 +96,12 @@ class OperationFolder {
                                     Dialect *dialect, Attribute value,
                                     Type type, Location loc);
 
-  // Fuse `foldedLocation` into the Location of `retainedOp`. This will result
-  // in `retainedOp` having a FusedLoc with `fusedLocationTag` to help trace the
-  // source of the fusion. If `retainedOp` already had a FusedLoc with the same
-  // tag, `foldedLocation` will simply be appended to it.
+  /// Fuse `foldedLocation` into `originalLocation` and return the result, a
+  /// FusedLoc tagged with `fusedLocationTag` to help trace the source of the
+  /// fusion. If `originalLocation` is already a FusedLoc with that tag,
+  /// `foldedLocation` is appended to it. If both are equal, nothing is fused.
+  Location getFusedLocation(Location originalLocation, Location foldedLocation);
+  /// Update the location of `retainedOp` by applying `getFusedLocation`.
   void appendFoldedLocation(Operation *retainedOp, Location foldedLocation);
 
   /// Tag for annotating fused locations as a result of merging constants.
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e12..b2afe08430fc4 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -152,8 +152,11 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
   // anything. Otherwise, we move the constant to the insertion block.
   Block *insertBlock = &insertRegion->front();
   if (opBlock != insertBlock || (&insertBlock->front() != op &&
-                                 !isFolderOwnedConstant(op->getPrevNode())))
+                                 !isFolderOwnedConstant(op->getPrevNode()))) {
     op->moveBefore(&insertBlock->front());
+    // Fuse in the insertion region's location to record that `op` moved.
+    appendFoldedLocation(op, insertRegion->getLoc());
+  }
 
   folderConstOp = op;
   referencedDialects[op].push_back(op->getDialect());
@@ -237,6 +240,9 @@ OperationFolder::processFoldResults(Operation *op,
   auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
   auto &entry = insertRegion->front();
   rewriter.setInsertionPoint(&entry, entry.begin());
+  // Constants are materialized at the front of the insertion region, so fuse
+  // the region's location into the folded op's location used for them below.
+  Location loc = getFusedLocation(op->getLoc(), insertRegion->getLoc());
 
   // Get the constant map for the insertion region of this operation.
   auto &uniquedConstants = foldScopes[insertRegion];
@@ -259,8 +265,8 @@ OperationFolder::processFoldResults(Operation *op,
     // Check to see if there is a canonicalized version of this constant.
     auto res = op->getResult(i);
     Attribute attrRepl = foldResults[i].get<Attribute>();
-    if (auto *constOp = tryGetOrCreateConstant(
-            uniquedConstants, dialect, attrRepl, res.getType(), op->getLoc())) {
+    if (auto *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
+                                               attrRepl, res.getType(), loc)) {
       // Ensure that this constant dominates the operation we are replacing it
       // with. This may not automatically happen if the operation being folded
       // was inserted before the constant within the insertion block.
@@ -364,31 +370,36 @@ static Location FlattenFusedLocationRecursively(const Location loc) {
   return loc;
 }
 
-void OperationFolder::appendFoldedLocation(Operation *retainedOp,
+Location OperationFolder::getFusedLocation(Location originalLocation,
                                            Location foldedLocation) {
+  // If they're already equal, no need to fuse.
+  if (originalLocation == foldedLocation)
+    return originalLocation;
+
   // Append into existing fused location if it has the same tag.
   if (auto existingFusedLoc =
-          dyn_cast<FusedLocWith<StringAttr>>(retainedOp->getLoc())) {
+          dyn_cast<FusedLocWith<StringAttr>>(originalLocation)) {
     StringAttr existingMetadata = existingFusedLoc.getMetadata();
     if (existingMetadata == fusedLocationTag) {
       ArrayRef<Location> existingLocations = existingFusedLoc.getLocations();
       SetVector<Location> locations(existingLocations.begin(),
                                     existingLocations.end());
       locations.insert(foldedLocation);
-      Location newFusedLoc = FusedLoc::get(
-          retainedOp->getContext(), locations.takeVector(), existingMetadata);
-      retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
-      return;
+      Location newFusedLoc =
+          FusedLoc::get(originalLocation.getContext(), locations.takeVector(),
+                        existingMetadata);
+      return FlattenFusedLocationRecursively(newFusedLoc);
     }
   }
 
-  // Create a new fusedloc with retainedOp's loc and foldedLocation.
-  // If they're already equal, no need to fuse.
-  if (retainedOp->getLoc() == foldedLocation)
-    return;
-
+  // Otherwise, create a new FusedLoc of originalLocation and foldedLocation.
   Location newFusedLoc =
-      FusedLoc::get(retainedOp->getContext(),
-                    {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
-  retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
+      FusedLoc::get(originalLocation.getContext(),
+                    {originalLocation, foldedLocation}, fusedLocationTag);
+  return FlattenFusedLocationRecursively(newFusedLoc);
+}
+
+void OperationFolder::appendFoldedLocation(Operation *retainedOp,
+                                           Location foldedLocation) {
+  retainedOp->setLoc(getFusedLocation(retainedOp->getLoc(), foldedLocation));
 }
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
index 217cc29c0095e..76cfe4dfff56f 100644
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ b/mlir/test/Transforms/canonicalize-debuginfo.mlir
@@ -34,8 +34,31 @@ func.func @hoist_constant(%arg0: memref<8xi32>) {
     memref.store %1, %arg0[%arg1] : memref<8xi32>
   }
   return
-}
+// CHECK: return
+// CHECK-NEXT: } loc(#[[LocFunc:.*]])
+} loc("hoist_constant":2:0)
 
 // CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
 // CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]]])
+// CHECK-DAG: #[[LocFunc]] = loc("hoist_constant":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]], #[[LocConst1]]])
+
+// -----
+
+// CHECK-LABEL: func @rehoist_constant
+func.func @rehoist_constant(%arg0: memref<8xi32>) -> i32 {
+  // CHECK-NEXT: arith.constant 42 : i32 loc(#[[FusedLoc:.*]])
+  %0 = arith.constant 42 : i32 loc("rehoist_constant":0:0)
+  affine.for %arg1 = 0 to 8 {
+    %1 = arith.constant 42 : i32 loc("rehoist_constant":1:0)
+    memref.store %1, %arg0[%arg1] : memref<8xi32>
+  }
+  return %0 : i32
+// CHECK: return
+// CHECK-NEXT: } loc(#[[LocFunc:.*]])
+} loc("rehoist_constant":2:0)
+
+// CHECK-DAG: #[[LocConst0:.*]] = loc("rehoist_constant":0:0)
+// CHECK-DAG: #[[LocConst1:.*]] = loc("rehoist_constant":1:0)
+// CHECK-DAG: #[[LocFunc]] = loc("rehoist_constant":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]], #[[LocConst1]]])
diff --git a/mlir/test/Transforms/constant-fold-debuginfo.mlir b/mlir/test/Transforms/constant-fold-debuginfo.mlir
index 79a25f860a484..7e5ffc3c4d5b1 100644
--- a/mlir/test/Transforms/constant-fold-debuginfo.mlir
+++ b/mlir/test/Transforms/constant-fold-debuginfo.mlir
@@ -11,11 +11,13 @@ func.func @fold_and_merge() -> (i32, i32) {
   %3 = arith.constant 6 : i32 loc("fold_and_merge":1:0)
 
   return %2, %3: i32, i32
-}
+// CHECK: } loc(#[[LocFunc:.*]])
+} loc("fold_and_merge":2:0)
 
 // CHECK-DAG: #[[LocConst0:.*]] = loc("fold_and_merge":0:0)
 // CHECK-DAG: #[[LocConst1:.*]] = loc("fold_and_merge":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
+// CHECK-DAG: #[[LocFunc]] = loc("fold_and_merge":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])
 
 // -----
 
@@ -27,8 +29,30 @@ func.func @materialize_different_dialect() -> (f32, f32) {
   %2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":1:0)
 
   return %1, %2: f32, f32
-}
+// CHECK: } loc(#[[LocFunc:.*]])
+} loc("materialize_different_dialect":2:0)
 
 // CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_different_dialect":0:0)
 // CHECK-DAG: #[[LocConst1:.*]] = loc("materialize_different_dialect":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
+// CHECK-DAG: #[[LocFunc]] = loc("materialize_different_dialect":2:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocFunc]], #[[LocConst0]]])
+
+// -----
+
+// CHECK-LABEL: func @materialize_in_front
+func.func @materialize_in_front(%arg0: memref<8xi32>) {
+  // CHECK-NEXT: arith.constant 6 : i32 loc(#[[FusedLoc:.*]])
+  affine.for %arg1 = 0 to 8 {
+    %1 = arith.constant 1 : i32
+    %2 = arith.constant 5 : i32
+    %3 = arith.addi %1, %2 : i32 loc("materialize_in_front":0:0)
+    memref.store %3, %arg0[%arg1] : memref<8xi32>
+  }
+  // CHECK: return
+  return
+// CHECK-NEXT: } loc(#[[LocFunc:.*]])
+} loc("materialize_in_front":1:0)
+
+// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_in_front":0:0)
+// CHECK-DAG: #[[LocFunc]] = loc("materialize_in_front":1:0)
+// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocFunc]]])



More information about the Mlir-commits mailing list