[Mlir-commits] [mlir] 1d0d7da - [mlir] Add symbol user attribute interface. (#153206)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 10 06:13:37 PST 2025


Author: Jacques Pienaar
Date: 2025-12-10T14:13:33Z
New Revision: 1d0d7da57cdf9f29ab5bb3aca7be46302ea2b9fe

URL: https://github.com/llvm/llvm-project/commit/1d0d7da57cdf9f29ab5bb3aca7be46302ea2b9fe
DIFF: https://github.com/llvm/llvm-project/commit/1d0d7da57cdf9f29ab5bb3aca7be46302ea2b9fe.diff

LOG: [mlir] Add symbol user attribute interface. (#153206)

Enables verification of attributes that reference symbols, independent of the op
they are attached to. This lets an Attribute's symbol usage be verified on its
own (i.e., the validity is a property of the Attribute, not of the operation).

---------

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>

Added: 
    mlir/test/IR/test-verifiers-attr.mlir

Modified: 
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/include/mlir/IR/SymbolInterfaces.td
    mlir/include/mlir/IR/SymbolTable.h
    mlir/lib/IR/SymbolTable.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestAttributes.h
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 683e2feaddef2..0b3079cde568d 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,4 +1,9 @@
 add_mlir_interface(SymbolInterfaces)
+set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
+mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRSymbolInterfacesAttrInterfaceIncGen)
+add_dependencies(mlir-generic-headers MLIRSymbolInterfacesAttrInterfaceIncGen)
 add_mlir_interface(RegionKindInterface)
 
 set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index b3aafe063d376..ebe0c26637ad3 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -210,7 +210,7 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
     This interface describes an operation that may use a `Symbol`. This
     interface allows for users of symbols to hook into verification and other
     symbol related utilities that are either costly or otherwise disallowed
-    within a traditional operation.
+    within an operation.
   }];
   let cppNamespace = "::mlir";
 
@@ -222,6 +222,25 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
   ];
 }
 
+def SymbolUserAttrInterface : AttrInterface<"SymbolUserAttrInterface"> {
+  let description = [{
+    This interface describes an attribute that may use a `Symbol`. This
+    interface allows for users of symbols to hook into verification and other
+    symbol related utilities that are either costly or otherwise disallowed
+    within an operation (e.g., recreating symbol users per verified op rather
+    than once per symbol table, or querying the symbol usage of siblings).
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<"Verify the symbol uses held by this attribute of this operation.",
+      "::llvm::LogicalResult", "verifySymbolUses",
+      (ins "::mlir::Operation *":$op,
+           "::mlir::SymbolTableCollection &":$symbolTable)
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Symbol Traits
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index e4622354b8980..a174062d8d019 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -499,5 +499,6 @@ ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
 
 /// Include the generated symbol interfaces.
 #include "mlir/IR/SymbolInterfaces.h.inc"
+#include "mlir/IR/SymbolInterfacesAttrInterface.h.inc"
 
 #endif // MLIR_IR_SYMBOLTABLE_H

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 87b47992905e0..9f5dd2c9e3b72 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -511,7 +511,17 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
   SymbolTableCollection symbolTable;
   auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
     if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
-      return WalkResult(user.verifySymbolUses(symbolTable));
+      if (failed(user.verifySymbolUses(symbolTable)))
+        return WalkResult::interrupt();
+    // Attributes may verify their own symbol uses, independent of the op.
+    // Only the op's discardable attributes are inspected, and only at the top
+    // level: attributes nested inside other attributes are not visited.
+    for (auto &attr : op->getDiscardableAttrs()) {
+      if (auto user = dyn_cast<SymbolUserAttrInterface>(attr.getValue())) {
+        if (failed(user.verifySymbolUses(op, symbolTable)))
+          return WalkResult::interrupt();
+      }
+    }
     return WalkResult::advance();
   };
 
@@ -1132,3 +1142,4 @@ ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
 
 /// Include the generated symbol interfaces.
 #include "mlir/IR/SymbolInterfaces.cpp.inc"
+#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"

diff  --git a/mlir/test/IR/test-verifiers-attr.mlir b/mlir/test/IR/test-verifiers-attr.mlir
new file mode 100644
index 0000000000000..4cbf7babbcf60
--- /dev/null
+++ b/mlir/test/IR/test-verifiers-attr.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Test basic symbol verification using discardable attribute.
+module {
+  func.func @existing_symbol() { return }
+
+  func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@existing_symbol>} { return }
+}
+
+// -----
+
+// Test invalid symbol reference, symbol does not exist.
+module {
+  // expected-error@+1 {{TestSymbolRefAttr::verifySymbolUses: '@non_existent_symbol' does not reference a valid symbol}}
+  func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@non_existent_symbol>} { return }
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9e7e4f883b576..f7e2273954693 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/IR/TensorEncoding.td"
 
 // All of the attributes will extend this class.
@@ -456,4 +457,17 @@ def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
   let assemblyFormat = "`<` $dummy `>`";
 }
 
+// Test attribute that implements SymbolUserAttrInterface.
+def TestSymbolRefAttr : Test_Attr<"TestSymbolRef",
+    [DeclareAttrInterfaceMethods<SymbolUserAttrInterface>]> {
+  let mnemonic = "symbol_ref_attr";
+  let summary = "Test attribute that references a symbol";
+  let description = [{
+    This attribute holds a reference to a symbol and implements
+    SymbolUserAttrInterface to verify that the referenced symbol exists.
+  }];
+  let parameters = (ins "::mlir::FlatSymbolRefAttr":$symbol);
+  let assemblyFormat = "`<` $symbol `>`";
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 9db7b01dd193b..0576809ccdebd 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -223,6 +223,21 @@ LogicalResult TestCopyCountAttr::verify(
  return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestSymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+TestSymbolRefAttr::verifySymbolUses(Operation *op,
+                                    SymbolTableCollection &symbolTable) const {
+  // Verify that the referenced symbol exists.
+  if (!symbolTable.lookupNearestSymbolFrom<SymbolOpInterface>(op, getSymbol()))
+    return op->emitOpError()
+           << "TestSymbolRefAttr::verifySymbolUses: '" << getSymbol()
+           << "' does not reference a valid symbol";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CopyCountAttr Implementation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 0ad5ab641c6d0..d4887ebd33c0a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -20,10 +20,11 @@
 #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Attributes.h"
-#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TensorEncoding.h"
 
 // generated files require above includes to come first

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e6f577d6665c2..9816d9c411cb3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -85,6 +85,8 @@ exports_files(glob(["include/**/*.td"]))
         tbl_outs = {
             "include/mlir/IR/" + name + ".h.inc": ["-gen-op-interface-decls"],
             "include/mlir/IR/" + name + ".cpp.inc": ["-gen-op-interface-defs"],
+            "include/mlir/IR/" + name + "AttrInterface.h.inc": ["-gen-attr-interface-decls"],
+            "include/mlir/IR/" + name + "AttrInterface.cpp.inc": ["-gen-attr-interface-defs"],
         },
         tblgen = ":mlir-tblgen",
         td_file = "include/mlir/IR/" + name + ".td",


        


More information about the Mlir-commits mailing list