[Mlir-commits] [mlir] [mlir] Extend tests of SymbolTable::replaceAllSymbolUses. (PR #68780)

Ingo Müller llvmlistbot at llvm.org
Wed Oct 11 02:27:03 PDT 2023


https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/68780

This is a follow-up commit for 479057887f at llvm/llvm-project (#68320) that adds more tests. In particular, the tests now check that the `limit` op itself is not traversed, i.e., symbols in attributes of the `limit` op are not renamed.

>From da99b86d9668079a1eaf6515de67236dea06ce48 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Wed, 11 Oct 2023 09:23:37 +0000
Subject: [PATCH] [mlir] Extend tests of SymbolTable::replaceAllSymbolUses.

This is a follow-up commit for 479057887f at llvm/llvm-project (#68320)
that adds more tests. In particular, the tests now check that the
`limit` op itself is not traversed, i.e., symbols in attributes of
the `limit` op are not renamed.
---
 mlir/unittests/IR/SymbolTableTest.cpp | 30 ++++++++++++++++-----------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 5dcec749f0f4259..12f582874d7f549 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
   void SetUp() override {
     ::test::registerTestDialect(registry);
     context = std::make_unique<MLIRContext>(registry);
+    builder = std::make_unique<OpBuilder>(context.get());
   }
 
   void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
     // Set up IR and find func ops.
     OwningOpRef<ModuleOp> module =
         parseSourceString<ModuleOp>(kInput, context.get());
+    ASSERT_TRUE(module);
     SymbolTable symbolTable(module.get());
     auto opIterator = module->getBody(0)->getOperations().begin();
     auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
     ASSERT_TRUE(succeeded(res));
     ASSERT_TRUE(succeeded(verify(module.get())));
 
-    // Check that it got renamed.
+    // Check that callee of the call op got renamed.
     bool calleeFound = false;
     fooOp->walk([&](CallOpInterface callOp) {
       StringAttr callee = callOp.getCallableForCallee()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
       calleeFound = true;
     });
     EXPECT_TRUE(calleeFound);
+
+    // Check that module attribute did *not* get renamed.
+    auto moduleAttr = (*module)->getAttrOfType<FlatSymbolRefAttr>("test.attr");
+    ASSERT_TRUE(moduleAttr);
+    EXPECT_EQ(moduleAttr.getValue(), StringRef("bar"));
   }
 
   std::unique_ptr<MLIRContext> context;
+  std::unique_ptr<OpBuilder> builder;
 
 private:
   constexpr static llvm::StringLiteral kInput = R"MLIR(
-      module {
+      module attributes { test.attr = @bar } {
         test.conversion_func_op private @foo() {
           "test.conversion_call_op"() { callee=@bar } : () -> ()
           "test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        barOp, StringAttr::get(context.get(), "baz"), module);
+        barOp, builder->getStringAttr("baz"), module);
   });
 }
 
@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        StringAttr::get(context.get(), "bar"),
-        StringAttr::get(context.get(), "baz"), module);
+        builder->getStringAttr("bar"), builder->getStringAttr("baz"), module);
   });
 }
 
@@ -100,7 +107,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+        barOp, builder->getStringAttr("baz"), &module->getRegion(0));
   });
 }
 
@@ -108,9 +115,9 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
   // Symbol as `StringAttr`, rename within module body.
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
-    return symbolTable.replaceAllSymbolUses(
-        StringAttr::get(context.get(), "bar"),
-        StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+    return symbolTable.replaceAllSymbolUses(builder->getStringAttr("bar"),
+                                            builder->getStringAttr("baz"),
+                                            &module->getRegion(0));
   });
 }
 
@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        barOp, StringAttr::get(context.get(), "baz"), fooOp);
+        barOp, builder->getStringAttr("baz"), fooOp);
   });
 }
 
@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
   testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
                                auto barOp) -> LogicalResult {
     return symbolTable.replaceAllSymbolUses(
-        StringAttr::get(context.get(), "bar"),
-        StringAttr::get(context.get(), "baz"), fooOp);
+        builder->getStringAttr("bar"), builder->getStringAttr("baz"), fooOp);
   });
 }
 



More information about the Mlir-commits mailing list