[Mlir-commits] [mlir] [mlir] Extend tests of SymbolTable::replaceAllSymbolUses. (PR #68780)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 11 02:28:17 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ingo Müller (ingomueller-net)
<details>
<summary>Changes</summary>
This is a follow-up to commit 479057887f (llvm/llvm-project#68320) that adds more tests. In particular, the tests now check that the `limit` op itself is not traversed, i.e., symbols in attributes of the `limit` op are not renamed.
---
Full diff: https://github.com/llvm/llvm-project/pull/68780.diff
1 File Affected:
- (modified) mlir/unittests/IR/SymbolTableTest.cpp (+18-12)
``````````diff
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 5dcec749f0f4259..12f582874d7f549 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
void SetUp() override {
::test::registerTestDialect(registry);
context = std::make_unique<MLIRContext>(registry);
+ builder = std::make_unique<OpBuilder>(context.get());
}
void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
// Set up IR and find func ops.
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(kInput, context.get());
+ ASSERT_TRUE(module);
SymbolTable symbolTable(module.get());
auto opIterator = module->getBody(0)->getOperations().begin();
auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
ASSERT_TRUE(succeeded(res));
ASSERT_TRUE(succeeded(verify(module.get())));
- // Check that it got renamed.
+ // Check that callee of the call op got renamed.
bool calleeFound = false;
fooOp->walk([&](CallOpInterface callOp) {
StringAttr callee = callOp.getCallableForCallee()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
calleeFound = true;
});
EXPECT_TRUE(calleeFound);
+
+ // Check that module attribute did *not* get renamed.
+ auto moduleAttr = (*module)->getAttrOfType<FlatSymbolRefAttr>("test.attr");
+ ASSERT_TRUE(moduleAttr);
+ EXPECT_EQ(moduleAttr.getValue(), StringRef("bar"));
}
std::unique_ptr<MLIRContext> context;
+ std::unique_ptr<OpBuilder> builder;
private:
constexpr static llvm::StringLiteral kInput = R"MLIR(
- module {
+ module attributes { test.attr = @bar } {
test.conversion_func_op private @foo() {
"test.conversion_call_op"() { callee=@bar } : () -> ()
"test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- barOp, StringAttr::get(context.get(), "baz"), module);
+ barOp, builder->getStringAttr("baz"), module);
});
}
@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- StringAttr::get(context.get(), "bar"),
- StringAttr::get(context.get(), "baz"), module);
+ builder->getStringAttr("bar"), builder->getStringAttr("baz"), module);
});
}
@@ -100,7 +107,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+ barOp, builder->getStringAttr("baz"), &module->getRegion(0));
});
}
@@ -108,9 +115,9 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
// Symbol as `StringAttr`, rename within module body.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
- return symbolTable.replaceAllSymbolUses(
- StringAttr::get(context.get(), "bar"),
- StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+ return symbolTable.replaceAllSymbolUses(builder->getStringAttr("bar"),
+ builder->getStringAttr("baz"),
+ &module->getRegion(0));
});
}
@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- barOp, StringAttr::get(context.get(), "baz"), fooOp);
+ barOp, builder->getStringAttr("baz"), fooOp);
});
}
@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- StringAttr::get(context.get(), "bar"),
- StringAttr::get(context.get(), "baz"), fooOp);
+ builder->getStringAttr("bar"), builder->getStringAttr("baz"), fooOp);
});
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/68780
More information about the Mlir-commits
mailing list