[Mlir-commits] [mlir] [mlir][tblgen] Fix region and successor references in custom directives (PR #146242)

Henrich Lauko llvmlistbot at llvm.org
Tue Jul 1 00:30:05 PDT 2025


https://github.com/xlauko updated https://github.com/llvm/llvm-project/pull/146242

>From 350e6d46aba58a4b2249e749f0b56c702caaeff5 Mon Sep 17 00:00:00 2001
From: xlauko <xlauko at mail.muni.cz>
Date: Sat, 28 Jun 2025 22:58:01 +0200
Subject: [PATCH] [mlir][tblgen] Fix region and successor references in custom
 directives

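Previously, a `ref(...)` directive naming a region or successor was always rejected, even when that region or successor had already been bound: the already-bound case fell through to the "regions/successors can only be used at the top level" error. This change allows already-bound regions and successors to be referenced via `ref(...)`, for example as arguments to custom directives.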
---
 mlir/test/IR/region.mlir                      |  7 ++++++
 .../test/lib/Dialect/Test/TestFormatUtils.cpp | 23 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestFormatUtils.h  | 18 +++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 18 +++++++++++++++
 mlir/test/mlir-tblgen/op-format-spec.td       | 13 +++++++++++
 mlir/test/mlir-tblgen/op-format.td            | 20 ++++++++++++++++
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        | 20 +++++++++-------
 7 files changed, 111 insertions(+), 8 deletions(-)

diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir
index 0b959915d6bbb..e2088817c5204 100644
--- a/mlir/test/IR/region.mlir
+++ b/mlir/test/IR/region.mlir
@@ -106,3 +106,10 @@ func.func @named_region_has_wrong_number_of_blocks() {
 test.single_no_terminator_custom_asm_op {
   "important_dont_drop"() : () -> ()
 }
+
+// -----
+
+// CHECK: test.dummy_op_with_region_ref
+test.dummy_op_with_region_ref {
+  ^bb0:
+}
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
index 9ed1b3a47be36..70bab21b83256 100644
--- a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp
@@ -381,3 +381,26 @@ void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
                               Attribute attr) {
   printer.printAttributeWithoutType(attr);
 }
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveDummyRegionRef
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseDummyRegionRef(OpAsmParser &parser, Region &region) {
+  return success(); // The region itself is parsed elsewhere in the format.
+}
+
+void test::printDummyRegionRef(OpAsmPrinter &printer, Operation *op,
+                               Region &region) { /* Printed elsewhere. */ }
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveDummySuccessorRef
+//===----------------------------------------------------------------------===//
+
+ParseResult test::parseDummySuccessorRef(OpAsmParser &parser,
+                                         Block *successor) {
+  return success(); // The successor itself is parsed elsewhere.
+}
+
+void test::printDummySuccessorRef(OpAsmPrinter &printer, Operation *op,
+                                  Block *successor) { /* Printed elsewhere. */ }
diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
index 6d4df7d82ffa5..e914f9a27b79b 100644
--- a/mlir/test/lib/Dialect/Test/TestFormatUtils.h
+++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h
@@ -207,6 +207,24 @@ mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser,
 void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op,
                         mlir::TypeAttr type, mlir::Attribute attr);
 
+//===----------------------------------------------------------------------===//
+// CustomDirectiveDummyRegionRef
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseDummyRegionRef(mlir::OpAsmParser &parser,
+                                      mlir::Region &region);
+void printDummyRegionRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                         mlir::Region &region);
+
+//===----------------------------------------------------------------------===//
+// CustomDirectiveDummySuccessorRef
+//===----------------------------------------------------------------------===//
+
+mlir::ParseResult parseDummySuccessorRef(mlir::OpAsmParser &parser,
+                                         mlir::Block *successor);
+void printDummySuccessorRef(mlir::OpAsmPrinter &printer, mlir::Operation *op,
+                            mlir::Block *successor);
+
 } // end namespace test
 
 #endif // MLIR_TESTFORMATUTILS_H
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1c961d272f192..0ad5bfa9a58ab 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3665,4 +3665,22 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
   );
 }
 
+//===----------------------------------------------------------------------===//
+// Test assembly format references
+//===----------------------------------------------------------------------===//
+
+def TestOpWithRegionRef : TEST_Op<"dummy_op_with_region_ref", [NoTerminator]> {
+  let regions = (region AnyRegion:$body); // Bound before ref($body).
+  let assemblyFormat = [{
+    $body attr-dict custom<DummyRegionRef>(ref($body))
+  }];
+}
+
+def TestOpWithSuccessorRef : TEST_Op<"dummy_op_with_successor_ref"> {
+  let successors = (successor AnySuccessor:$successor); // Bound before ref().
+  let assemblyFormat = [{
+    $successor attr-dict custom<DummySuccessorRef>(ref($successor))
+  }];
+}
+
 #endif // TEST_OPS
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 02bf65609b21a..03b63f42c7767 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -49,6 +49,19 @@ def DirectiveCustomValidD : TestFormat_Op<[{
 def DirectiveCustomValidE : TestFormat_Op<[{
   custom<MyDirective>(prop-dict) attr-dict
 }]>, Arguments<(ins UnitAttr:$flag)>;
+def DirectiveCustomValidF : TestFormat_Op<[{
+  $operand custom<MyDirective>(ref($operand)) attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>; // ref() of a bound operand.
+def DirectiveCustomValidG : TestFormat_Op<[{
+  $body custom<MyDirective>(ref($body)) attr-dict
+}]> {
+  let regions = (region AnyRegion:$body); // ref() of a bound region.
+}
+def DirectiveCustomValidH : TestFormat_Op<[{
+  $successor custom<MyDirective>(ref($successor)) attr-dict
+}]> {
+  let successors = (successor AnySuccessor:$successor); // ref() of successor.
+}
 
 //===----------------------------------------------------------------------===//
 // functional-type
diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
index 09e068b91a40b..1790737a3a349 100644
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -109,3 +109,23 @@ def OptionalGroupC : TestFormat_Op<[{
 def OptionalGroupD : TestFormat_Op<[{
   (custom<Custom>($a, $b)^)? attr-dict
 }], [AttrSizedOperandSegments]>, Arguments<(ins Optional<I64>:$a, Optional<I64>:$b)>;
+
+// CHECK-LABEL: RegionRef::parse
+// CHECK:   auto odsResult = parseCustom(parser, *bodyRegion);
+// CHECK-LABEL: RegionRef::print
+// CHECK:   printCustom(_odsPrinter, *this, getBody());
+def RegionRef : TestFormat_Op<[{
+  $body custom<Custom>(ref($body)) attr-dict
+}]> {
+  let regions = (region AnyRegion:$body); // Bound before ref($body).
+}
+
+// CHECK-LABEL: SuccessorRef::parse
+// CHECK:   auto odsResult = parseCustom(parser, successorSuccessor);
+// CHECK-LABEL: SuccessorRef::print
+// CHECK:   printCustom(_odsPrinter, *this, getSuccessor());
+def SuccessorRef : TestFormat_Op<[{
+  $successor custom<Custom>(ref($successor)) attr-dict
+}]> {
+  let successors = (successor AnySuccessor:$successor); // Bound before ref().
+}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index d27814bc4541e..14af7787a833e 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -3376,11 +3376,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
     if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
       if (hasAllRegions || !seenRegions.insert(region).second)
         return emitError(loc, "region '" + name + "' is already bound");
-    } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
-      return emitError(loc, "region '" + name +
-                                "' must be bound before it is referenced");
+    } else if (ctx == RefDirectiveContext) {
+      if (!seenRegions.count(region))
+        return emitError(loc, "region '" + name +
+                                  "' must be bound before it is referenced");
     } else {
-      return emitError(loc, "regions can only be used at the top level");
+      return emitError(loc, "regions can only be used at the top level "
+                            "or in a ref directive");
     }
     return create<RegionVariable>(region);
   }
@@ -3396,11 +3398,13 @@ OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
     if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
       if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
         return emitError(loc, "successor '" + name + "' is already bound");
-    } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
-      return emitError(loc, "successor '" + name +
-                                "' must be bound before it is referenced");
+    } else if (ctx == RefDirectiveContext) {
+      if (!seenSuccessors.count(successor))
+        return emitError(loc, "successor '" + name +
+                                  "' must be bound before it is referenced");
     } else {
-      return emitError(loc, "successors can only be used at the top level");
+      return emitError(loc, "successors can only be used at the top level "
+                            "or in a ref directive");
     }
 
     return create<SuccessorVariable>(successor);



More information about the Mlir-commits mailing list