[Mlir-commits] [mlir] [mlir] Make overloads of SymbolTable::replaceAllSymbolUses consistent. (PR #68320)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 5 08:17:51 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This function has several overloads that allow specifying, in different ways, the symbol to be renamed and the scope in which to rename it. The overloads were inconsistent in the following way (quoted strings are `StringAttr`s; other variables are `Operation *`):

* `replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp)` would traverse into the symbol table of `scopeOp`.
* `replaceAllSymbolUses("symbol", "new_symbol", scopeOp)` would *not* traverse into the symbol table of `scopeOp`.

The underlying behavior is spread over several places and is somewhat hard to follow. The two overloads above differ mainly in what `collectSymbolScopes`, which is itself overloaded, computes. If `scopeOp` is a top-level module, the `(Operation *, Operation *)` overload, used in the first case above, computes a scope whose `limit` is the body region of the module; the `(StringAttr, Operation *)` overload, however, used the module op itself as the `limit`. Later, `walkSymbolTable` would walk the body of the module when it was given as a region, but it would *not* enter the regions of the module op, because that op has a symbol table (which was assumed to be a *different* scope).

The fix in this commit changes the behavior of `collectSymbolScopes` so that the `(StringAttr, Operation *)` overload returns one scope for each region of the `limit` argument.
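
The mismatch and the fix can be reproduced in a small Python model. This is a sketch under stated assumptions, not MLIR code: `Op`, `Region`, `walk_op`, `walk_region`, `count_uses`, `collect_scopes_old` and `collect_scopes_new` are hypothetical names, `is_symbol_table` stands in for `OpTrait::SymbolTable`, and the model ignores nested symbol references and unregistered ops. It only encodes the two rules described above: a region limit is walked in full, while an op limit that defines a symbol table is not entered.

```python
"""Toy model of MLIR symbol-use scoping (hypothetical names, not the MLIR API)."""
from dataclasses import dataclass, field
from typing import List, Tuple, Union


@dataclass
class Region:
    ops: List["Op"] = field(default_factory=list)


@dataclass
class Op:
    name: str
    refs: List[str] = field(default_factory=list)  # symbol names referenced by this op
    regions: List[Region] = field(default_factory=list)
    is_symbol_table: bool = False  # stands in for OpTrait::SymbolTable


Limit = Union[Op, Region]
Scope = Tuple[str, Limit]


def walk_op(op: Op, out: List[Op]) -> None:
    out.append(op)  # the op's own references are always visited
    if not op.is_symbol_table:  # a nested symbol table is treated as a different scope
        for region in op.regions:
            walk_region(region, out)


def walk_region(region: Region, out: List[Op]) -> None:
    for op in region.ops:
        walk_op(op, out)


def count_uses(scopes: List[Scope]) -> int:
    """Count references to each scope's symbol that the walk can reach."""
    total = 0
    for symbol, limit in scopes:
        visited: List[Op] = []
        if isinstance(limit, Region):
            walk_region(limit, visited)
        else:
            walk_op(limit, visited)
        total += sum(op.refs.count(symbol) for op in visited)
    return total


def collect_scopes_old(symbol: str, limit: Limit) -> List[Scope]:
    # Before the fix (modelled): the limit is used as-is, even when it is an op.
    return [(symbol, limit)]


def collect_scopes_new(symbol: str, limit: Limit) -> List[Scope]:
    # After the fix (modelled): an op limit becomes one scope per region of that op.
    if isinstance(limit, Region):
        return [(symbol, limit)]
    return [(symbol, region) for region in limit.regions]


# A module containing @foo (whose body calls @bar) and @bar.
call = Op("func.call", refs=["bar"])
foo = Op("func.func @foo", regions=[Region([call])])
bar = Op("func.func @bar")
module = Op("builtin.module", regions=[Region([foo, bar])], is_symbol_table=True)

print(count_uses(collect_scopes_old("bar", module)))  # 0: module body never entered
print(count_uses(collect_scopes_new("bar", module)))  # 1: the call inside @foo
print(count_uses(collect_scopes_new("bar", foo)))     # 1: scoping to @foo still works
```

In this model the walk rules are untouched and only the shape of the scope list differs between the two runs on `module`, mirroring how the diff below modifies only `collectSymbolScopes`.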

---
Full diff: https://github.com/llvm/llvm-project/pull/68320.diff


2 Files Affected:

- (modified) mlir/lib/IR/SymbolTable.cpp (+10-2) 
- (modified) mlir/test/python/ir/symbol_table.py (+13-2) 


``````````diff
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 2494cb7086f0d7d..b69f230f5108f62 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -655,12 +655,20 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
     scopes.back().limit = limit;
   return scopes;
 }
-template <typename IRUnit>
 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
-                                                       IRUnit *limit) {
+                                                       Region *limit) {
   return {{SymbolRefAttr::get(symbol), limit}};
 }
 
+static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
+                                                       Operation *limit) {
+  SmallVector<SymbolScope, 1> scopes;
+  auto symbolRef = SymbolRefAttr::get(symbol);
+  for (auto &region : limit->getRegions())
+    scopes.push_back({symbolRef, &region});
+  return scopes;
+}
+
 /// Returns true if the given reference 'SubRef' is a sub reference of the
 /// reference 'ref', i.e. 'ref' is a further qualified reference.
 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 3264cfcf9a10495..577721ab2111f55 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -106,9 +106,9 @@ def testSymbolTableRAUW():
       """
         )
         foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+
+        # Do renaming just within `foo`.
         SymbolTable.set_symbol_name(bar, "bam")
-        # Note that module.operation counts as a "nested symbol table" which won't
-        # be traversed into, so it is necessary to traverse its children.
         SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
         # CHECK: call @bam()
         # CHECK: func private @bam
@@ -118,6 +118,17 @@ def testSymbolTableRAUW():
         print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
         print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
 
+        # Do renaming within the module.
+        SymbolTable.set_symbol_name(bar, "baz")
+        SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
+        # CHECK: call @baz()
+        # CHECK: func private @baz
+        print(m)
+        # CHECK: Foo symbol: StringAttr("foo")
+        # CHECK: Bar symbol: StringAttr("baz")
+        print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
+        print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
+
 
 # CHECK-LABEL: testSymbolTableVisibility
 @run

``````````

</details>


https://github.com/llvm/llvm-project/pull/68320

