[Mlir-commits] [mlir] [mlir] Extend tests of SymbolTable::replaceAllSymbolUses. (PR #68780)
Ingo Müller
llvmlistbot at llvm.org
Wed Oct 11 02:27:03 PDT 2023
https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/68780
This is a follow-up to 479057887f at llvm/llvm-project (#68320) that adds more tests. In particular, the tests now check that the `limit` op itself is not traversed, i.e., symbols in attributes of the `limit` op are not renamed.
>From da99b86d9668079a1eaf6515de67236dea06ce48 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Wed, 11 Oct 2023 09:23:37 +0000
Subject: [PATCH] [mlir] Extend tests of SymbolTable::replaceAllSymbolUses.
This is a follow-up commit for 479057887f at llvm/llvm-project (#68320)
that adds more tests. In particular, the tests now check that the
`limit` op itself is not traversed, i.e., symbols in attributes of
the `limit` op are not renamed.
---
mlir/unittests/IR/SymbolTableTest.cpp | 30 ++++++++++++++++-----------
1 file changed, 18 insertions(+), 12 deletions(-)
diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp
index 5dcec749f0f4259..12f582874d7f549 100644
--- a/mlir/unittests/IR/SymbolTableTest.cpp
+++ b/mlir/unittests/IR/SymbolTableTest.cpp
@@ -28,12 +28,14 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
void SetUp() override {
::test::registerTestDialect(registry);
context = std::make_unique<MLIRContext>(registry);
+ builder = std::make_unique<OpBuilder>(context.get());
}
void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
// Set up IR and find func ops.
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(kInput, context.get());
+ ASSERT_TRUE(module);
SymbolTable symbolTable(module.get());
auto opIterator = module->getBody(0)->getOperations().begin();
auto fooOp = cast<FunctionOpInterface>(opIterator++);
@@ -46,7 +48,7 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
ASSERT_TRUE(succeeded(res));
ASSERT_TRUE(succeeded(verify(module.get())));
- // Check that it got renamed.
+ // Check that callee of the call op got renamed.
bool calleeFound = false;
fooOp->walk([&](CallOpInterface callOp) {
StringAttr callee = callOp.getCallableForCallee()
@@ -56,13 +58,19 @@ class ReplaceAllSymbolUsesTest : public ::testing::Test {
calleeFound = true;
});
EXPECT_TRUE(calleeFound);
+
+ // Check that module attribute did *not* get renamed.
+ auto moduleAttr = (*module)->getAttrOfType<FlatSymbolRefAttr>("test.attr");
+ ASSERT_TRUE(moduleAttr);
+ EXPECT_EQ(moduleAttr.getValue(), StringRef("bar"));
}
std::unique_ptr<MLIRContext> context;
+ std::unique_ptr<OpBuilder> builder;
private:
constexpr static llvm::StringLiteral kInput = R"MLIR(
- module {
+ module attributes { test.attr = @bar } {
test.conversion_func_op private @foo() {
"test.conversion_call_op"() { callee=@bar } : () -> ()
"test.return"() : () -> ()
@@ -81,7 +89,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- barOp, StringAttr::get(context.get(), "baz"), module);
+ barOp, builder->getStringAttr("baz"), module);
});
}
@@ -90,8 +98,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- StringAttr::get(context.get(), "bar"),
- StringAttr::get(context.get(), "baz"), module);
+ builder->getStringAttr("bar"), builder->getStringAttr("baz"), module);
});
}
@@ -100,7 +107,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+ barOp, builder->getStringAttr("baz"), &module->getRegion(0));
});
}
@@ -108,9 +115,9 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
// Symbol as `StringAttr`, rename within module body.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
- return symbolTable.replaceAllSymbolUses(
- StringAttr::get(context.get(), "bar"),
- StringAttr::get(context.get(), "baz"), &module->getRegion(0));
+ return symbolTable.replaceAllSymbolUses(builder->getStringAttr("bar"),
+ builder->getStringAttr("baz"),
+ &module->getRegion(0));
});
}
@@ -119,7 +126,7 @@ TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- barOp, StringAttr::get(context.get(), "baz"), fooOp);
+ barOp, builder->getStringAttr("baz"), fooOp);
});
}
@@ -128,8 +135,7 @@ TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
- StringAttr::get(context.get(), "bar"),
- StringAttr::get(context.get(), "baz"), fooOp);
+ builder->getStringAttr("bar"), builder->getStringAttr("baz"), fooOp);
});
}
More information about the Mlir-commits
mailing list