[Mlir-commits] [mlir] [mlir][Symbol] Add verification that symbol's parent is a SymbolTable (PR #80590)
Joshua Cao
llvmlistbot at llvm.org
Sun Feb 4 14:11:26 PST 2024
https://github.com/caojoshua updated https://github.com/llvm/llvm-project/pull/80590
>From 7d8add4fc757072f123cf4b44cd125bc91eb275c Mon Sep 17 00:00:00 2001
From: Joshua Cao <cao.joshua at yahoo.com>
Date: Sun, 3 Dec 2023 14:12:11 -0800
Subject: [PATCH 1/2] [mlir][Symbol] Add verification that symbol's parent is a
SymbolTable
---
mlir/include/mlir/IR/SymbolInterfaces.td | 5 +++
mlir/test/Dialect/LLVMIR/global.mlir | 2 +-
.../Dialect/Linalg/transform-op-replace.mlir | 6 ++--
mlir/test/Dialect/Transform/ops-invalid.mlir | 3 +-
mlir/test/IR/invalid-func-op.mlir | 4 +--
mlir/test/IR/region.mlir | 7 ++--
mlir/test/IR/traits.mlir | 33 +++++++++----------
mlir/test/Transforms/canonicalize-dce.mlir | 14 ++++----
mlir/test/Transforms/canonicalize.mlir | 13 ++++----
mlir/test/Transforms/constant-fold.mlir | 11 ++++---
mlir/test/Transforms/cse.mlir | 11 ++++---
mlir/test/Transforms/test-legalizer-full.mlir | 8 +++--
12 files changed, 64 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837..0bd5de9f18920 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -171,6 +171,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
if (concreteOp.isDeclaration() && concreteOp.isPublic())
return concreteOp.emitOpError("symbol declaration cannot have public "
"visibility");
+ auto parent = $_op->getParentOp();
+ if (parent && !parent->hasTrait<OpTrait::SymbolTable>()) {
+ return concreteOp.emitOpError("symbol's parent must have the SymbolTable "
+ "trait");
+ }
return success();
}];
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 0649e814bfdfc..3fa7636d4dd68 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label
// -----
func.func @foo() {
- // expected-error @+1 {{must appear at the module level}}
+ // expected-error @+1 {{op symbol's parent must have the SymbolTable trait}}
llvm.mlir.global internal @bar(42) : i32
return
diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
index 2801522e81ac2..1a40912977dec 100644
--- a/mlir/test/Dialect/Linalg/transform-op-replace.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -12,8 +12,10 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.replace %0 {
- func.func @foo() {
- "dummy_op"() : () -> ()
+ builtin.module {
+ func.func @foo() {
+ "dummy_op"() : () -> ()
+ }
}
} : (!transform.any_op) -> !transform.any_op
transform.yield
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index e3f5bcf403f2a..73a5f36af9295 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -433,10 +433,9 @@ module {
// -----
module attributes { transform.with_named_sequence} {
- // expected-note @below {{ancestor transform op}}
transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
- // expected-error @below {{cannot be defined inside another transform op}}
+ // expected-error @below {{op symbol's parent must have the SymbolTable trait}}
transform.named_sequence @nested() {
transform.yield
}
diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
index d995689ebb8d0..8fd7af22e9598 100644
--- a/mlir/test/IR/invalid-func-op.mlir
+++ b/mlir/test/IR/invalid-func-op.mlir
@@ -31,7 +31,7 @@ func.func @func_op() {
// -----
func.func @func_op() {
- // expected-error at +1 {{entry block must have 1 arguments to match function signature}}
+ // expected-error at +1 {{op symbol's parent must have the SymbolTable trait}}
func.func @mixed_named_arguments(f32) {
^entry:
return
@@ -42,7 +42,7 @@ func.func @func_op() {
// -----
func.func @func_op() {
- // expected-error at +1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}}
+ // expected-error at +1 {{op symbol's parent must have the SymbolTable trait}}
func.func @mixed_named_arguments(f32) {
^entry(%arg : i32):
return
diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index bf4b1bb4e5ab1..0b959915d6bbb 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -87,18 +87,17 @@ func.func @named_region_has_wrong_number_of_blocks() {
// CHECK: test.single_no_terminator_op
"test.single_no_terminator_op"() (
{
- func.func @foo1() { return }
- func.func @foo2() { return }
+ %foo = arith.constant 1 : i32
}
) : () -> ()
// CHECK: test.variadic_no_terminator_op
"test.variadic_no_terminator_op"() (
{
- func.func @foo1() { return }
+ %foo = arith.constant 1 : i32
},
{
- func.func @foo2() { return }
+ %bar = arith.constant 1 : i32
}
) : () -> ()
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 0402ebe758750..1e046706379cd 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -572,15 +572,13 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () {
// Ensure that SSACFG regions of operations in GRAPH regions are
// checked for dominance
-func.func @illegalInsideDominanceFreeScope() -> () {
+func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () {
test.graph_region {
- func.func @test() -> i1 {
- ^bb1:
+ scf.if %cond {
// expected-error @+1 {{operand #0 does not dominate this use}}
%2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
// expected-note @+1 {{operand defined here}}
- %1 = "baz"(%2#0) : (i1) -> (i64)
- return %2#1 : i1
+ %1 = "baz"(%2#0) : (i1) -> (i64)
}
"terminator"() : () -> ()
}
@@ -591,20 +589,21 @@ func.func @illegalInsideDominanceFreeScope() -> () {
// Ensure that SSACFG regions of operations in GRAPH regions are
// checked for dominance
-func.func @illegalCDFGInsideDominanceFreeScope() -> () {
+func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () {
test.graph_region {
- func.func @test() -> i1 {
- ^bb1:
- // expected-error @+1 {{operand #0 does not dominate this use}}
- %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
- cf.br ^bb4
- ^bb2:
- cf.br ^bb2
- ^bb4:
- %1 = "foo"() : ()->i64 // expected-note {{operand defined here}}
- return %2#1 : i1
+ scf.if %cond {
+ "test.ssacfg_region"() ({
+ ^bb1:
+ // expected-error @+1 {{operand #0 does not dominate this use}}
+ %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
+ cf.br ^bb4
+ ^bb2:
+ cf.br ^bb2
+ ^bb4:
+ %1 = "foo"() : ()->i64 // expected-note {{operand defined here}}
+ }) : () -> ()
}
- "terminator"() : () -> ()
+ "terminator"() : () -> ()
}
return
}
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 46545d2e9fd51..3048a7fed636b 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) {
// Test case: Recursively DCE into enclosed regions.
-// CHECK: func @f(%arg0: f32)
-// CHECK-NEXT: func @g(%arg1: f32)
-// CHECK-NEXT: return
+// CHECK: func.func @f(%arg0: f32)
+// CHECK-NOT: arith.addf
func.func @f(%arg0: f32) {
- func.func @g(%arg1: f32) {
- %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32
- return
- }
+ "test.region"() (
+ {
+ %0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32
+ }
+ ) : () -> ()
return
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 9b578e6c2631a..2cf86b50d432f 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -424,16 +424,18 @@ func.func @write_only_alloca_fold(%v: f32) {
// CHECK-LABEL: func @dead_block_elim
func.func @dead_block_elim() {
// CHECK-NOT: ^bb
- func.func @nested() {
- return
+ builtin.module {
+ func.func @nested() {
+ return
- ^bb1:
- return
+ ^bb1:
+ return
+ }
}
return

^bb1:
return
}
// CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 45ee03fa31d25..253163f2af911 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -756,12 +756,15 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
// CHECK-LABEL: func @nested_isolated_region
func.func @nested_isolated_region() {
+ // CHECK-NEXT: builtin.module {
// CHECK-NEXT: func @isolated_op
// CHECK-NEXT: arith.constant 2
- func.func @isolated_op() {
- %0 = arith.constant 1 : i32
- %2 = arith.addi %0, %0 : i32
- "foo.yield"(%2) : (i32) -> ()
+ builtin.module {
+ func.func @isolated_op() {
+ %0 = arith.constant 1 : i32
+ %2 = arith.addi %0, %0 : i32
+ "foo.yield"(%2) : (i32) -> ()
+ }
}
// CHECK: "foo.unknown_region"
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d..11a3310268473 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -228,11 +228,14 @@ func.func @nested_isolated() -> i32 {
// CHECK-NEXT: arith.constant 1
%0 = arith.constant 1 : i32
+ // CHECK-NEXT: builtin.module
// CHECK-NEXT: @nested_func
- func.func @nested_func() {
- // CHECK-NEXT: arith.constant 1
- %foo = arith.constant 1 : i32
- "foo.yield"(%foo) : (i32) -> ()
+ builtin.module {
+ func.func @nested_func() {
+ // CHECK-NEXT: arith.constant 1
+ %foo = arith.constant 1 : i32
+ "foo.yield"(%foo) : (i32) -> ()
+ }
}
// CHECK: "foo.region"
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d..4268f18e611c0 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -37,9 +37,11 @@ func.func @recursively_legal_invalid_op() {
}
/// Operation that is dynamically legal, i.e. the function has a pattern
/// applied to legalize the argument type before it becomes recursively legal.
- func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
- %ignored = "test.illegal_op_f"() : () -> (i32)
- "test.return"() : () -> ()
+ builtin.module {
+ func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
+ %ignored = "test.illegal_op_f"() : () -> (i32)
+ "test.return"() : () -> ()
+ }
}
"test.return"() : () -> ()
>From 2cd26b443faacb745c62586b0e730203cb0b97dd Mon Sep 17 00:00:00 2001
From: Joshua Cao <cao.joshua at yahoo.com>
Date: Sun, 4 Feb 2024 14:11:20 -0800
Subject: [PATCH 2/2] Only require the SymbolTable trait on registered parent ops
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
mlir/include/mlir/IR/SymbolInterfaces.td | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 0bd5de9f18920..60b38185fa8cc 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -172,7 +172,8 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
return concreteOp.emitOpError("symbol declaration cannot have public "
"visibility");
auto parent = $_op->getParentOp();
- if (parent && !parent->hasTrait<OpTrait::SymbolTable>()) {
+ if (parent && parent->isRegistered() &&
+ !parent->hasTrait<OpTrait::SymbolTable>()) {
return concreteOp.emitOpError("symbol's parent must have the SymbolTable "
"trait");
}
More information about the Mlir-commits
mailing list