[Mlir-commits] [mlir] [mlir][irdl] Lookup symbols near dialects instead of locally (PR #92819)
Théo Degioanni
llvmlistbot at llvm.org
Mon May 20 13:58:57 PDT 2024
https://github.com/Moxinilian created https://github.com/llvm/llvm-project/pull/92819
Because symbol references are resolved relative to the nearest enclosing symbol table, it was impossible to refer to operations outside of the dialect currently being defined. This PR changes the lookup logic so that it happens relative to the symbol table containing the dialect-defining operation. This is a bit of a hack, but it should unblock the situation here.
>From f6ccbe282d692bff7fe9e392e0e2ff5d9644fd4c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?=
<theo.degioanni.llvm.deluge062 at simplelogin.fr>
Date: Mon, 20 May 2024 21:55:44 +0100
Subject: [PATCH] lookup symbols near dialects instead of locally
---
mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h | 37 ++++++++++++++++++
mlir/lib/Dialect/IRDL/CMakeLists.txt | 1 +
mlir/lib/Dialect/IRDL/IR/IRDL.cpp | 7 +++-
mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp | 10 +++--
mlir/lib/Dialect/IRDL/IRDLLoading.cpp | 3 +-
mlir/lib/Dialect/IRDL/IRDLSymbols.cpp | 38 +++++++++++++++++++
mlir/test/Dialect/IRDL/cmath.irdl.mlir | 8 ++--
mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir | 12 +++---
mlir/test/Dialect/IRDL/invalid.irdl.mlir | 9 +----
mlir/test/Dialect/IRDL/testd.irdl.mlir | 12 +++---
10 files changed, 108 insertions(+), 29 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
create mode 100644 mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
new file mode 100644
index 0000000000000..4b7292c054ec2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
@@ -0,0 +1,37 @@
+//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages lookup logic for IRDL dialect-absolute symbols.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Resolves `symbol` from the symbol table enclosing the IRDL dialect
+/// definition operation that contains `source`; `source` must be nested within
+/// such a DialectOp. Returns null if no symbol is found. This exploits the
+/// SymbolTableCollection to reuse symbol tables across repeated lookups.
+Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+ Operation *source, SymbolRefAttr symbol);
+
+/// Resolves `symbol` from the symbol table enclosing the IRDL dialect
+/// definition operation that contains `source`, without any cache; `source`
+/// must be nested within such a DialectOp. Returns null if nothing is found.
+Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index d25760e5d29bc..db4b98ef5308e 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
IR/IRDL.cpp
IR/IRDLOps.cpp
IRDLLoading.cpp
+ IRDLSymbols.cpp
IRDLVerifiers.cpp
DEPENDS
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index e4728f55b49d7..1f5584fa30c27 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
return success();
}
+/// Finds whether the provided symbol is an IRDL type or attribute definition.
+/// The source operation must be within a DialectOp.
static LogicalResult
checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
Operation *source, SymbolRefAttr symbol) {
- Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
+ Operation *targetOp =
+ irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
+
if (!targetOp)
return source->emitOpError() << "symbol '" << symbol << "' not found";
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index 0895306b8bce1..7ec3aa2741023 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/ValueRange.h"
#include <optional>
@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
// Case where the input is a symbol reference.
// This corresponds to the case where the base is an IRDL type or attribute.
if (auto baseRef = getBaseRef()) {
+ // The verifier for BaseOp guarantees it is within a dialect.
Operation *defOp =
- SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+ irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
// Type case.
if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
SmallVector<unsigned> constraints =
getConstraintIndicesForArgs(getArgs(), valueToConstr);
- // Symbol reference case for the base
+ // Symbol reference case for the base.
+ // The verifier for ParametricOp guarantees it is within a dialect.
SymbolRefAttr symRef = getBaseType();
- Operation *defOp =
- SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+ Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
if (!defOp) {
emitError() << symRef << " does not refer to any existing symbol";
return nullptr;
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 5df2b45d8037b..5f623e8845d10 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
// For `irdl.parametric`, we get directly the base from the operation.
if (auto params = dyn_cast<ParametricOp>(op)) {
SymbolRefAttr symRef = params.getBaseType();
- Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+ Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
assert(defOp && "symbol reference should refer to an existing operation");
paramIrdlOps.insert(defOp);
return false;
diff --git a/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
new file mode 100644
index 0000000000000..ff2136df364d9
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
@@ -0,0 +1,38 @@
+//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Returns the parent of the DialectOp that is or encloses `source`: the scope
+/// dialect-absolute symbols resolve from. Returns null if that DialectOp has no
+/// parent, or (after a debug assert) if `source` is outside any dialect.
+static Operation *getDialectSymbolScope(Operation *source) {
+  Operation *dialectOp = source;
+  while (dialectOp && !isa<DialectOp>(dialectOp))
+    dialectOp = dialectOp->getParentOp();
+  assert(dialectOp && "symbol lookup near dialect must originate from "
+                      "within a dialect definition");
+  return dialectOp ? dialectOp->getParentOp() : nullptr;
+}
+
+Operation *
+mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                    Operation *source, SymbolRefAttr symbol) {
+  Operation *scope = getDialectSymbolScope(source);
+  return scope ? symbolTable.lookupNearestSymbolFrom(scope, symbol) : nullptr;
+}
+
+Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
+                                               SymbolRefAttr symbol) {
+  Operation *scope = getDialectSymbolScope(source);
+  return scope ? SymbolTable::lookupNearestSymbolFrom(scope, symbol) : nullptr;
+}
diff --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
index 997af08d24733..0b7e220ceb90c 100644
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -19,13 +19,13 @@ module {
// CHECK: irdl.operation @norm {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
// CHECK: irdl.operands(%[[v1]])
// CHECK: irdl.results(%[[v0]])
// CHECK: }
irdl.operation @norm {
%0 = irdl.any
- %1 = irdl.parametric @complex<%0>
+ %1 = irdl.parametric @cmath::@complex<%0>
irdl.operands(%1)
irdl.results(%0)
}
@@ -34,7 +34,7 @@ module {
// CHECK: %[[v0:[^ ]*]] = irdl.is f32
// CHECK: %[[v1:[^ ]*]] = irdl.is f64
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
- // CHECK: %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
+ // CHECK: %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
// CHECK: irdl.operands(%[[v3]], %[[v3]])
// CHECK: irdl.results(%[[v3]])
// CHECK: }
@@ -42,7 +42,7 @@ module {
%0 = irdl.is f32
%1 = irdl.is f64
%2 = irdl.any_of(%0, %1)
- %3 = irdl.parametric @complex<%2>
+ %3 = irdl.parametric @cmath::@complex<%2>
irdl.operands(%3, %3)
irdl.results(%3)
}
diff --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
index db8dfc5cb36ca..cbcc248bf00b1 100644
--- a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
@@ -6,14 +6,14 @@
irdl.dialect @testd {
// CHECK: irdl.type @self_referencing {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
// CHECK: }
irdl.type @self_referencing {
%0 = irdl.any
- %1 = irdl.parametric @self_referencing<%0>
+ %1 = irdl.parametric @testd::@self_referencing<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
@@ -22,13 +22,13 @@ irdl.dialect @testd {
// CHECK: irdl.type @type1 {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
irdl.type @type1 {
%0 = irdl.any
- %1 = irdl.parametric @type2<%0>
+ %1 = irdl.parametric @testd::@type2<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
@@ -36,13 +36,13 @@ irdl.dialect @testd {
// CHECK: irdl.type @type2 {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
irdl.type @type2 {
%0 = irdl.any
- %1 = irdl.parametric @type1<%0>
+ %1 = irdl.parametric @testd::@type1<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
index f207d31cf158b..93ad619358750 100644
--- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -2,8 +2,6 @@
// Testing invalid IRDL IRs
-func.func private @foo()
-
irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
// -----
+func.func private @not_a_type_or_attr()
+
irdl.dialect @invalid_parametric {
irdl.operation @foo {
// expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
%param = irdl.parametric @not_a_type_or_attr<>
irdl.results(%param)
}
-
- irdl.operation @not_a_type_or_attr {
- %param = irdl.is i1
- irdl.results(%param)
- }
}
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index f828d95bdb81d..aeb1a83747ecc 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -76,20 +76,20 @@ irdl.dialect @testd {
}
// CHECK: irdl.operation @dyn_type_base {
- // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
+ // CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric
// CHECK: irdl.results(%[[v1]])
// CHECK: }
irdl.operation @dyn_type_base {
- %0 = irdl.base @parametric
+ %0 = irdl.base @testd::@parametric
irdl.results(%0)
}
// CHECK: irdl.operation @dyn_attr_base {
- // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
+ // CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
// CHECK: }
irdl.operation @dyn_attr_base {
- %0 = irdl.base @parametric_attr
+ %0 = irdl.base @testd::@parametric_attr
irdl.attributes {"attr1" = %0}
}
@@ -115,14 +115,14 @@ irdl.dialect @testd {
// CHECK: %[[v0:[^ ]*]] = irdl.is i32
// CHECK: %[[v1:[^ ]*]] = irdl.is i64
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
- // CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
+ // CHECK: %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
// CHECK: irdl.results(%[[v3]])
// CHECK: }
irdl.operation @dynparams {
%0 = irdl.is i32
%1 = irdl.is i64
%2 = irdl.any_of(%0, %1)
- %3 = irdl.parametric @parametric<%2>
+ %3 = irdl.parametric @testd::@parametric<%2>
irdl.results(%3)
}
More information about the Mlir-commits
mailing list