[Mlir-commits] [mlir] [mlir][Symbol] Add verification that symbol's parent is a SymbolTable (PR #80590)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 3 23:41:11 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Joshua Cao (caojoshua)

<details>
<summary>Changes</summary>

Following the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table.

I changed some tests to pass the new verification. In most cases, we can wrap the func in a module, change the func to another op with regions (e.g. scf.if), or change the expected error message.

---
Full diff: https://github.com/llvm/llvm-project/pull/80590.diff


12 Files Affected:

- (modified) mlir/include/mlir/IR/SymbolInterfaces.td (+5) 
- (modified) mlir/test/Dialect/LLVMIR/global.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/transform-op-replace.mlir (+4-2) 
- (modified) mlir/test/Dialect/Transform/ops-invalid.mlir (+1-2) 
- (modified) mlir/test/IR/invalid-func-op.mlir (+2-2) 
- (modified) mlir/test/IR/region.mlir (+3-4) 
- (modified) mlir/test/IR/traits.mlir (+16-17) 
- (modified) mlir/test/Transforms/canonicalize-dce.mlir (+7-7) 
- (modified) mlir/test/Transforms/canonicalize.mlir (+6-7) 
- (modified) mlir/test/Transforms/constant-fold.mlir (+7-4) 
- (modified) mlir/test/Transforms/cse.mlir (+7-4) 
- (modified) mlir/test/Transforms/test-legalizer-full.mlir (+5-3) 


``````````diff
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 844601f8f6837c..0bd5de9f18920e 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -171,6 +171,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     if (concreteOp.isDeclaration() && concreteOp.isPublic())
       return concreteOp.emitOpError("symbol declaration cannot have public "
              "visibility");
+    auto parent = $_op->getParentOp();
+    if (parent && !parent->hasTrait<OpTrait::SymbolTable>()) {
+      return concreteOp.emitOpError("symbol's parent must have the SymbolTable "
+             "trait");
+    }
     return success();
   }];
 
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index 0649e814bfdfc0..3fa7636d4dd686 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label
 // -----
 
 func.func @foo() {
-  // expected-error @+1 {{must appear at the module level}}
+  // expected-error @+1 {{op symbol's parent must have the SymbolTable trait}}
   llvm.mlir.global internal @bar(42) : i32
 
   return
diff --git a/mlir/test/Dialect/Linalg/transform-op-replace.mlir b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
index 2801522e81ac2c..1a40912977dec2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-replace.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-replace.mlir
@@ -12,8 +12,10 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.structured.replace %0 {
-      func.func @foo() {
-        "dummy_op"() : () -> ()
+      builtin.module {
+        func.func @foo() {
+          "dummy_op"() : () -> ()
+        }
       }
     } : (!transform.any_op) -> !transform.any_op
     transform.yield
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index e3f5bcf403f2ad..73a5f36af92952 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -433,10 +433,9 @@ module {
 // -----
 
 module attributes { transform.with_named_sequence} {
-  // expected-note @below {{ancestor transform op}}
   transform.sequence failures(suppress) {
   ^bb0(%arg0: !transform.any_op):
-    // expected-error @below {{cannot be defined inside another transform op}}
+    // expected-error @below {{op symbol's parent must have the SymbolTable trai}}
     transform.named_sequence @nested() {
       transform.yield
     }
diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
index d995689ebb8d0b..8fd7af22e9598b 100644
--- a/mlir/test/IR/invalid-func-op.mlir
+++ b/mlir/test/IR/invalid-func-op.mlir
@@ -31,7 +31,7 @@ func.func @func_op() {
 // -----
 
 func.func @func_op() {
-  // expected-error at +1 {{entry block must have 1 arguments to match function signature}}
+  // expected-error at +1 {{op symbol's parent must have the SymbolTable trait}}
   func.func @mixed_named_arguments(f32) {
   ^entry:
     return
@@ -42,7 +42,7 @@ func.func @func_op() {
 // -----
 
 func.func @func_op() {
-  // expected-error at +1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}}
+  // expected-error at +1 {{op symbol's parent must have the SymbolTable trait}}
   func.func @mixed_named_arguments(f32) {
   ^entry(%arg : i32):
     return
diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index bf4b1bb4e5ab1d..0b959915d6bbbe 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -87,18 +87,17 @@ func.func @named_region_has_wrong_number_of_blocks() {
 // CHECK: test.single_no_terminator_op
 "test.single_no_terminator_op"() (
   {
-    func.func @foo1() { return }
-    func.func @foo2() { return }
+    %foo = arith.constant 1 : i32
   }
 ) : () -> ()
 
 // CHECK: test.variadic_no_terminator_op
 "test.variadic_no_terminator_op"() (
   {
-    func.func @foo1() { return }
+    %foo = arith.constant 1 : i32
   },
   {
-    func.func @foo2() { return }
+    %bar = arith.constant 1 : i32
   }
 ) : () -> ()
 
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 0402ebe7587508..1e046706379cdb 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -572,15 +572,13 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () {
 
 // Ensure that SSACFG regions of operations in GRAPH regions are
 // checked for dominance
-func.func @illegalInsideDominanceFreeScope() -> () {
+func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () {
   test.graph_region {
-    func.func @test() -> i1 {
-    ^bb1:
+    scf.if %cond {
       // expected-error @+1 {{operand #0 does not dominate this use}}
       %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
       // expected-note @+1 {{operand defined here}}
-	   %1 = "baz"(%2#0) : (i1) -> (i64)
-      return %2#1 : i1
+      %1 = "baz"(%2#0) : (i1) -> (i64)
     }
     "terminator"() : () -> ()
   }
@@ -591,20 +589,21 @@ func.func @illegalInsideDominanceFreeScope() -> () {
 
 // Ensure that SSACFG regions of operations in GRAPH regions are
 // checked for dominance
-func.func @illegalCDFGInsideDominanceFreeScope() -> () {
+func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () {
   test.graph_region {
-    func.func @test() -> i1 {
-    ^bb1:
-      // expected-error @+1 {{operand #0 does not dominate this use}}
-      %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
-      cf.br ^bb4
-    ^bb2:
-      cf.br ^bb2
-    ^bb4:
-      %1 = "foo"() : ()->i64   // expected-note {{operand defined here}}
-		return %2#1 : i1
+    scf.if %cond {
+      "test.ssacfg_region"() ({
+      ^bb1:
+        // expected-error @+1 {{operand #0 does not dominate this use}}
+        %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
+        cf.br ^bb4
+      ^bb2:
+        cf.br ^bb2
+      ^bb4:
+        %1 = "foo"() : ()->i64   // expected-note {{operand defined here}}
+      }) : () -> ()
     }
-     "terminator"() : () -> ()
+    "terminator"() : () -> ()
   }
   return
 }
diff --git a/mlir/test/Transforms/canonicalize-dce.mlir b/mlir/test/Transforms/canonicalize-dce.mlir
index 46545d2e9fd510..3048a7fed636b5 100644
--- a/mlir/test/Transforms/canonicalize-dce.mlir
+++ b/mlir/test/Transforms/canonicalize-dce.mlir
@@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) {
 
 // Test case: Recursively DCE into enclosed regions.
 
-// CHECK:      func @f(%arg0: f32)
-// CHECK-NEXT:   func @g(%arg1: f32)
-// CHECK-NEXT:     return
+// CHECK:      func.func @f(%arg0: f32)
+// CHECK-NOT:     arith.addf
 
 func.func @f(%arg0: f32) {
-  func.func @g(%arg1: f32) {
-    %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32
-    return
-  }
+  "test.region"() (
+    {
+      %0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32
+    }
+  ) : () -> ()
   return
 }
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 9b578e6c2631a7..2cf86b50d432f6 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -424,16 +424,15 @@ func.func @write_only_alloca_fold(%v: f32) {
 // CHECK-LABEL: func @dead_block_elim
 func.func @dead_block_elim() {
   // CHECK-NOT: ^bb
-  func.func @nested() {
-    return
+  builtin.module {
+    func.func @nested() {
+      return
 
-  ^bb1:
-    return
+    ^bb1:
+      return
+    }
   }
   return
-
-^bb1:
-  return
 }
 
 // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 45ee03fa31d25f..253163f2af9110 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -756,12 +756,15 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 
 // CHECK-LABEL: func @nested_isolated_region
 func.func @nested_isolated_region() {
+  // CHECK-NEXT: builtin.module {
   // CHECK-NEXT: func @isolated_op
   // CHECK-NEXT: arith.constant 2
-  func.func @isolated_op() {
-    %0 = arith.constant 1 : i32
-    %2 = arith.addi %0, %0 : i32
-    "foo.yield"(%2) : (i32) -> ()
+  builtin.module {
+    func.func @isolated_op() {
+      %0 = arith.constant 1 : i32
+      %2 = arith.addi %0, %0 : i32
+      "foo.yield"(%2) : (i32) -> ()
+    }
   }
 
   // CHECK: "foo.unknown_region"
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
index c764d2b9bd57d8..11a33102684733 100644
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -228,11 +228,14 @@ func.func @nested_isolated() -> i32 {
   // CHECK-NEXT: arith.constant 1
   %0 = arith.constant 1 : i32
 
+  // CHECK-NEXT: builtin.module
   // CHECK-NEXT: @nested_func
-  func.func @nested_func() {
-    // CHECK-NEXT: arith.constant 1
-    %foo = arith.constant 1 : i32
-    "foo.yield"(%foo) : (i32) -> ()
+  builtin.module {
+    func.func @nested_func() {
+      // CHECK-NEXT: arith.constant 1
+      %foo = arith.constant 1 : i32
+      "foo.yield"(%foo) : (i32) -> ()
+    }
   }
 
   // CHECK: "foo.region"
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index ecb17d5f1b67d4..4268f18e611c0a 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -37,9 +37,11 @@ func.func @recursively_legal_invalid_op() {
   }
   /// Operation that is dynamically legal, i.e. the function has a pattern
   /// applied to legalize the argument type before it becomes recursively legal.
-  func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
-    %ignored = "test.illegal_op_f"() : () -> (i32)
-    "test.return"() : () -> ()
+  builtin.module {
+    func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
+      %ignored = "test.illegal_op_f"() : () -> (i32)
+      "test.return"() : () -> ()
+    }
   }
 
   "test.return"() : () -> ()

``````````

</details>


https://github.com/llvm/llvm-project/pull/80590


More information about the Mlir-commits mailing list