[Mlir-commits] [mlir] [mlir] Make overloads of SymbolTable::replaceAllSymbolUses consistent. (PR #68320)

Ingo Müller llvmlistbot at llvm.org
Fri Oct 6 08:37:57 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/68320

>From abe59c4fcfc65fe2f2ff924212d078737d8a1c10 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Thu, 5 Oct 2023 14:32:04 +0000
Subject: [PATCH 1/3] [mlir] Make overloads of
 SymbolTable::replaceAllSymbolUses consistent.

This function has several overloads that allow specifying the symbol
that should be renamed and the scope for that renaming in different
ways. The overloads were inconsistent in the following way (quoted
strings are `StringAttr`s, other variables are `Operation *`):

* `replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp)` would traverse
  into the symbol table of `scopeOp`.
* `replaceAllSymbolUses("symbol", "new_symbol", scopeOp)` would *not*
  traverse into the symbol table of `scopeOp`.

The underlying behavior was spread over different places and is somewhat
hard to understand. The two overloads above mainly differed by what
`collectSymbolScopes` computed, which is itself overloaded. If `scopeOp`
is a top-level module, then the overload on
`(Operation *, Operation *)`, which is used in the first of the above
cases, computes a scope where the body region of the module is the
`limit`; however, the overload on `(StringAttr, Operation *)` computed
the module op itself as the `limit`. Later, `walkSymbolTable` would walk
the body of the module if it was given as a region but it would *not*
enter the regions of the module op because that op has a symbol table
(which was assumed to be a *different* scope).

The fix in this commit is to change the behavior of `collectSymbolScopes`
such that the `(StringAttr, Operation *)` overload returns a scope for
each region in the `limit` argument.
---
 mlir/lib/IR/SymbolTable.cpp         | 12 ++++++++++--
 mlir/test/python/ir/symbol_table.py | 15 +++++++++++++--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 2494cb7086f0d7d..b69f230f5108f62 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -655,12 +655,20 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
     scopes.back().limit = limit;
   return scopes;
 }
-template <typename IRUnit>
 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
-                                                       IRUnit *limit) {
+                                                       Region *limit) {
   return {{SymbolRefAttr::get(symbol), limit}};
 }
 
+static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
+                                                       Operation *limit) {
+  SmallVector<SymbolScope, 1> scopes;
+  auto symbolRef = SymbolRefAttr::get(symbol);
+  for (auto &region : limit->getRegions())
+    scopes.push_back({symbolRef, &region});
+  return scopes;
+}
+
 /// Returns true if the given reference 'SubRef' is a sub reference of the
 /// reference 'ref', i.e. 'ref' is a further qualified reference.
 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 3264cfcf9a10495..577721ab2111f55 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -106,9 +106,9 @@ def testSymbolTableRAUW():
       """
         )
         foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+
+        # Do renaming just within `foo`.
         SymbolTable.set_symbol_name(bar, "bam")
-        # Note that module.operation counts as a "nested symbol table" which won't
-        # be traversed into, so it is necessary to traverse its children.
         SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
         # CHECK: call @bam()
         # CHECK: func private @bam
@@ -118,6 +118,17 @@ def testSymbolTableRAUW():
         print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
         print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
 
+        # Do renaming within the module.
+        SymbolTable.set_symbol_name(bar, "baz")
+        SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
+        # CHECK: call @baz()
+        # CHECK: func private @baz
+        print(m)
+        # CHECK: Foo symbol: StringAttr("foo")
+        # CHECK: Bar symbol: StringAttr("baz")
+        print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
+        print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
+
 
 # CHECK-LABEL: testSymbolTableVisibility
 @run

>From 24a6403cce8f3345d424e468e731b41f1bb534ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 6 Oct 2023 09:04:27 +0000
Subject: [PATCH 2/3] Add unittest that tests all overloads.

---
 mlir/unittests/IR/CMakeLists.txt      |   1 +
 mlir/unittests/IR/SymbolTableTest.cpp | 107 ++++++++++++++++++++++++++
 2 files changed, 108 insertions(+)
 create mode 100644 mlir/unittests/IR/SymbolTableTest.cpp

diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 8a74a590962892f..6d05af193dfae01 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
+  SymbolTableTest.cpp
   TypeTest.cpp
   OpPropertiesTest.cpp
 
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
new file mode 100644
index 000000000000000..12bd281557ca3f4
--- /dev/null
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -0,0 +1,107 @@
+//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/IR/SymbolTable.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Parser/Parser.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+TEST(SymbolTableTest, ReplaceAllSymbolUses) {
+  MLIRContext context;
+  context.getOrLoadDialect<test::TestDialect>();
+
+  auto testReplaceAllSymbolUses = [&](auto replaceFn) {
+    const static llvm::StringLiteral input = R"MLIR(
+      module {
+        test.conversion_func_op private @foo() {
+          "test.conversion_call_op"() { callee=@bar } : () -> ()
+          "test.return"() : () -> ()
+        }
+        test.conversion_func_op private @bar()
+      }
+    )MLIR";
+
+    // Set up IR and find func ops.
+    OwningOpRef<Operation *> module = parseSourceString(input, &context);
+    SymbolTable symbolTable(module.get());
+    auto ops = module->getRegion(0).getBlocks().front().getOperations().begin();
+    auto fooOp = cast<FunctionOpInterface>(ops++);
+    auto barOp = cast<FunctionOpInterface>(ops++);
+    ASSERT_EQ(fooOp.getNameAttr(), "foo");
+    ASSERT_EQ(barOp.getNameAttr(), "bar");
+
+    // Call test function that does symbol replacement.
+    LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
+    ASSERT_TRUE(succeeded(res));
+    ASSERT_TRUE(succeeded(verify(module.get())));
+
+    // Check that it got renamed.
+    bool calleeFound = false;
+    fooOp->walk([&](CallOpInterface callOp) {
+      StringAttr callee = callOp.getCallableForCallee()
+                              .dyn_cast<SymbolRefAttr>()
+                              .getLeafReference();
+      EXPECT_EQ(callee, "baz");
+      calleeFound = true;
+    });
+    EXPECT_TRUE(calleeFound);
+  };
+
+  // Symbol as `Operation *`, rename within module.
+  testReplaceAllSymbolUses(
+      [&](auto symbolTable, auto module, auto fooOp, auto barOp) {
+        return symbolTable.replaceAllSymbolUses(
+            barOp, StringAttr::get(&context, "baz"), module);
+      });
+
+  // Symbol as `StringAttr`, rename within module.
+  testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
+                               auto barOp) {
+    return symbolTable.replaceAllSymbolUses(StringAttr::get(&context, "bar"),
+                                            StringAttr::get(&context, "baz"),
+                                            module);
+  });
+
+  // Symbol as `Operation *`, rename within module body.
+  testReplaceAllSymbolUses(
+      [&](auto symbolTable, auto module, auto fooOp, auto barOp) {
+        return symbolTable.replaceAllSymbolUses(
+            barOp, StringAttr::get(&context, "baz"), &module->getRegion(0));
+      });
+
+  // Symbol as `StringAttr`, rename within module body.
+  testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
+                               auto barOp) {
+    return symbolTable.replaceAllSymbolUses(StringAttr::get(&context, "bar"),
+                                            StringAttr::get(&context, "baz"),
+                                            &module->getRegion(0));
+  });
+
+  // Symbol as `Operation *`, rename within function.
+  testReplaceAllSymbolUses(
+      [&](auto symbolTable, auto module, auto fooOp, auto barOp) {
+        return symbolTable.replaceAllSymbolUses(
+            barOp, StringAttr::get(&context, "baz"), fooOp);
+      });
+
+  // Symbol as `StringAttr`, rename within function.
+  testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
+                               auto barOp) {
+    return symbolTable.replaceAllSymbolUses(StringAttr::get(&context, "bar"),
+                                            StringAttr::get(&context, "baz"),
+                                            fooOp);
+  });
+}
+
+} // namespace

>From bcd7d653ef3642650a4003a7fd8d89544d73ff9c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Fri, 6 Oct 2023 15:37:41 +0000
Subject: [PATCH 3/3] Address comments from @ftynse's review.

---
 mlir/unittests/IR/SymbolTableTest.cpp | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 12bd281557ca3f4..3a1ed1acdac1c28 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/IR/SymbolTable.h"
-#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -16,12 +16,20 @@
 
 using namespace mlir;
 
+namespace test {
+void registerTestDialect(DialectRegistry &);
+} // namespace test
+
 namespace {
 TEST(SymbolTableTest, ReplaceAllSymbolUses) {
-  MLIRContext context;
-  context.getOrLoadDialect<test::TestDialect>();
+  DialectRegistry registry;
+  ::test::registerTestDialect(registry);
+
+  MLIRContext context(registry);
 
-  auto testReplaceAllSymbolUses = [&](auto replaceFn) {
+  using ReplaceFnType = llvm::function_ref<LogicalResult(
+      SymbolTable, ModuleOp, Operation *, Operation *)>;
+  auto testReplaceAllSymbolUses = [&](ReplaceFnType replaceFn) {
     const static llvm::StringLiteral input = R"MLIR(
       module {
         test.conversion_func_op private @foo() {
@@ -33,11 +41,11 @@ TEST(SymbolTableTest, ReplaceAllSymbolUses) {
     )MLIR";
 
     // Set up IR and find func ops.
-    OwningOpRef<Operation *> module = parseSourceString(input, &context);
+    OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(input, &context);
     SymbolTable symbolTable(module.get());
-    auto ops = module->getRegion(0).getBlocks().front().getOperations().begin();
-    auto fooOp = cast<FunctionOpInterface>(ops++);
-    auto barOp = cast<FunctionOpInterface>(ops++);
+    auto opIterator = module->getBody(0)->getOperations().begin();
+    auto fooOp = cast<FunctionOpInterface>(opIterator++);
+    auto barOp = cast<FunctionOpInterface>(opIterator++);
     ASSERT_EQ(fooOp.getNameAttr(), "foo");
     ASSERT_EQ(barOp.getNameAttr(), "bar");
 



More information about the Mlir-commits mailing list