[Mlir-commits] [mlir] 1d0d7da - [mlir] Add symbol user attribute interface. (#153206)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 10 06:13:37 PST 2025
Author: Jacques Pienaar
Date: 2025-12-10T14:13:33Z
New Revision: 1d0d7da57cdf9f29ab5bb3aca7be46302ea2b9fe
URL: https://github.com/llvm/llvm-project/commit/1d0d7da57cdf9f29ab5bb3aca7be46302ea2b9fe
DIFF: https://github.com/llvm/llvm-project/commit/1d0d7da57cdf9f29ab5bb3aca7be46302ea2b9fe.diff
LOG: [mlir] Add symbol user attribute interface. (#153206)
Enables verification of attributes that reference symbols, independently of
the op they are attached to. This allows an Attribute's symbol usage to be
verified on its own (i.e., validity is a property of the Attribute, not of the
operation it is attached to).
---------
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Added:
mlir/test/IR/test-verifiers-attr.mlir
Modified:
mlir/include/mlir/IR/CMakeLists.txt
mlir/include/mlir/IR/SymbolInterfaces.td
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/IR/SymbolTable.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestAttributes.h
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 683e2feaddef2..0b3079cde568d 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -1,4 +1,11 @@
+# SymbolInterfaces.td also defines attribute interfaces. add_mlir_interface()
+# copies TABLEGEN_OUTPUT from this scope, so registering these outputs first
+# makes MLIRSymbolInterfacesIncGen generate them; registered afterwards, only
+# tablegen targets created later in this file would depend on them.
+set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
+mlir_tablegen(SymbolInterfacesAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(SymbolInterfacesAttrInterface.cpp.inc -gen-attr-interface-defs)
add_mlir_interface(SymbolInterfaces)
add_mlir_interface(RegionKindInterface)
set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index b3aafe063d376..ebe0c26637ad3 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -210,7 +210,7 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
This interface describes an operation that may use a `Symbol`. This
interface allows for users of symbols to hook into verification and other
symbol related utilities that are either costly or otherwise disallowed
- within a traditional operation.
+ within an operation.
}];
let cppNamespace = "::mlir";
@@ -222,6 +222,25 @@ def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
];
}
+def SymbolUserAttrInterface : AttrInterface<"SymbolUserAttrInterface"> {
+  let description = [{
+    This interface describes an attribute that may use a `Symbol`. This
+    interface allows for users of symbols to hook into verification and other
+    symbol related utilities that are either costly or otherwise disallowed
+    within an operation (e.g., recreating symbol users per op verified rather
+    than per symbol table, or querying the symbol usage of siblings).
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<"Verify the symbol uses held by this attribute attached to `op`.",
+      "::llvm::LogicalResult", "verifySymbolUses",
+      (ins "::mlir::Operation *":$op,
+           "::mlir::SymbolTableCollection &":$symbolTable)
+    >,
+  ];
+}
+
//===----------------------------------------------------------------------===//
// Symbol Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index e4622354b8980..a174062d8d019 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -499,5 +499,6 @@ ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
/// Include the generated symbol interfaces.
#include "mlir/IR/SymbolInterfaces.h.inc"
+#include "mlir/IR/SymbolInterfacesAttrInterface.h.inc"
#endif // MLIR_IR_SYMBOLTABLE_H
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 87b47992905e0..9f5dd2c9e3b72 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -511,7 +511,14 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
SymbolTableCollection symbolTable;
auto verifySymbolUserFn = [&](Operation *op) -> std::optional<WalkResult> {
if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
- return WalkResult(user.verifySymbolUses(symbolTable));
+      if (failed(user.verifySymbolUses(symbolTable)))
+        return WalkResult::interrupt();
+    // Discardable attributes that use symbols verify them against `op` too.
+    for (const NamedAttribute &namedAttr : op->getDiscardableAttrs()) {
+      auto attrUser = dyn_cast<SymbolUserAttrInterface>(namedAttr.getValue());
+      if (attrUser && failed(attrUser.verifySymbolUses(op, symbolTable)))
+        return WalkResult::interrupt();
+    }
return WalkResult::advance();
};
@@ -1132,3 +1139,4 @@ ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
/// Include the generated symbol interfaces.
#include "mlir/IR/SymbolInterfaces.cpp.inc"
+#include "mlir/IR/SymbolInterfacesAttrInterface.cpp.inc"
diff --git a/mlir/test/IR/test-verifiers-attr.mlir b/mlir/test/IR/test-verifiers-attr.mlir
new file mode 100644
index 0000000000000..4cbf7babbcf60
--- /dev/null
+++ b/mlir/test/IR/test-verifiers-attr.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Test that a discardable attribute referencing an existing symbol verifies.
+module {
+  func.func @existing_symbol() { return }
+
+  func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@existing_symbol>} { return }
+}
+
+// -----
+
+// Test invalid symbol reference, symbol does not exist.
+module {
+  // expected-error@+1 {{TestSymbolRefAttr::verifySymbolUses: '@non_existent_symbol' does not reference a valid symbol}}
+  func.func @test() attributes {symbol_ref = #test.symbol_ref_attr<@non_existent_symbol>} { return }
+}
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 9e7e4f883b576..f7e2273954693 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/TensorEncoding.td"
// All of the attributes will extend this class.
@@ -456,4 +457,17 @@ def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
let assemblyFormat = "`<` $dummy `>`";
}
+// Wraps a flat symbol reference and implements SymbolUserAttrInterface.
+def TestSymbolRefAttr : Test_Attr<"TestSymbolRef",
+    [DeclareAttrInterfaceMethods<SymbolUserAttrInterface>]> {
+  let summary = "Test attribute that references a symbol";
+  let description = [{
+    This attribute holds a reference to a symbol and implements
+    SymbolUserAttrInterface to verify that the referenced symbol exists.
+  }];
+  let mnemonic = "symbol_ref_attr";
+  let assemblyFormat = "`<` $symbol `>`";
+  let parameters = (ins "::mlir::FlatSymbolRefAttr":$symbol);
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 9db7b01dd193b..0576809ccdebd 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -223,6 +223,25 @@ LogicalResult TestCopyCountAttr::verify(
return success();
}
+//===----------------------------------------------------------------------===//
+// TestSymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+/// Implements SymbolUserAttrInterface: checks that `getSymbol()` can be looked
+/// up as a SymbolOpInterface from the symbol table nearest to `op`. On failure,
+/// emits an op error on `op` (the operation this attribute is attached to) and
+/// returns failure.
+LogicalResult
+TestSymbolRefAttr::verifySymbolUses(Operation *op,
+                                    SymbolTableCollection &symbolTable) const {
+  // Verify that the referenced symbol exists.
+  if (!symbolTable.lookupNearestSymbolFrom<SymbolOpInterface>(op, getSymbol()))
+    return op->emitOpError()
+           << "TestSymbolRefAttr::verifySymbolUses: '" << getSymbol()
+           << "' does not reference a valid symbol";
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// CopyCountAttr Implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 0ad5ab641c6d0..d4887ebd33c0a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -20,10 +20,11 @@
#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TensorEncoding.h"
// generated files require above includes to come first
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e6f577d6665c2..9816d9c411cb3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -85,6 +85,8 @@ exports_files(glob(["include/**/*.td"]))
tbl_outs = {
"include/mlir/IR/" + name + ".h.inc": ["-gen-op-interface-decls"],
"include/mlir/IR/" + name + ".cpp.inc": ["-gen-op-interface-defs"],
+ "include/mlir/IR/" + name + "AttrInterface.h.inc": ["-gen-attr-interface-decls"],
+ "include/mlir/IR/" + name + "AttrInterface.cpp.inc": ["-gen-attr-interface-defs"],
},
tblgen = ":mlir-tblgen",
td_file = "include/mlir/IR/" + name + ".td",
More information about the Mlir-commits
mailing list