[Mlir-commits] [mlir] [mlir][irdl] Lookup symbols near dialects instead of locally (PR #92819)
Théo Degioanni
llvmlistbot at llvm.org
Mon May 20 14:27:02 PDT 2024
https://github.com/Moxinilian updated https://github.com/llvm/llvm-project/pull/92819
>From f6ccbe282d692bff7fe9e392e0e2ff5d9644fd4c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?=
<theo.degioanni.llvm.deluge062 at simplelogin.fr>
Date: Mon, 20 May 2024 21:55:44 +0100
Subject: [PATCH] lookup symbols near dialects instead of locally
---
mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h | 37 ++++++++++++++++++
mlir/lib/Dialect/IRDL/CMakeLists.txt | 1 +
mlir/lib/Dialect/IRDL/IR/IRDL.cpp | 7 +++-
mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp | 10 +++--
mlir/lib/Dialect/IRDL/IRDLLoading.cpp | 3 +-
mlir/lib/Dialect/IRDL/IRDLSymbols.cpp | 38 +++++++++++++++++++
mlir/test/Dialect/IRDL/cmath.irdl.mlir | 8 ++--
mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir | 12 +++---
mlir/test/Dialect/IRDL/invalid.irdl.mlir | 9 +----
mlir/test/Dialect/IRDL/testd.irdl.mlir | 12 +++---
10 files changed, 108 insertions(+), 29 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
create mode 100644 mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
new file mode 100644
index 0000000000000..4b7292c054ec2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
@@ -0,0 +1,37 @@
+//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages lookup logic for IRDL dialect-absolute symbols.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Looks up `symbol` from the parent of the IRDL dialect definition enclosing
+/// `source`, so IRDL definitions are referenced as `@dialect::@name`.
+/// `source` must be, or be nested within, an IRDL dialect definition. Returns
+/// nullptr if not found. Symbol tables are cached in `symbolTable`.
+Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+ Operation *source, SymbolRefAttr symbol);
+
+/// Looks up `symbol` from the parent of the IRDL dialect definition enclosing
+/// `source`, which must be, or be nested within, such a definition. Unlike the
+/// overload above, symbol tables are not cached. Returns nullptr if not found.
+Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index d25760e5d29bc..db4b98ef5308e 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
IR/IRDL.cpp
IR/IRDLOps.cpp
IRDLLoading.cpp
+ IRDLSymbols.cpp
IRDLVerifiers.cpp
DEPENDS
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index e4728f55b49d7..1f5584fa30c27 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
return success();
}
+/// Finds whether the provided symbol is an IRDL type or attribute definition.
+/// The source operation must be within a DialectOp.
static LogicalResult
checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
Operation *source, SymbolRefAttr symbol) {
- Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
+ Operation *targetOp =
+ irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
+
if (!targetOp)
return source->emitOpError() << "symbol '" << symbol << "' not found";
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index 0895306b8bce1..7ec3aa2741023 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/ValueRange.h"
#include <optional>
@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
// Case where the input is a symbol reference.
// This corresponds to the case where the base is an IRDL type or attribute.
if (auto baseRef = getBaseRef()) {
+ // The verifier for BaseOp guarantees it is within a dialect.
Operation *defOp =
- SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+ irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
// Type case.
if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
SmallVector<unsigned> constraints =
getConstraintIndicesForArgs(getArgs(), valueToConstr);
- // Symbol reference case for the base
+ // Symbol reference case for the base.
+ // The verifier for ParametricOp guarantees it is within a dialect.
SymbolRefAttr symRef = getBaseType();
- Operation *defOp =
- SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+ Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
if (!defOp) {
emitError() << symRef << " does not refer to any existing symbol";
return nullptr;
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 5df2b45d8037b..5f623e8845d10 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
// For `irdl.parametric`, we get directly the base from the operation.
if (auto params = dyn_cast<ParametricOp>(op)) {
SymbolRefAttr symRef = params.getBaseType();
- Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+ Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
assert(defOp && "symbol reference should refer to an existing operation");
paramIrdlOps.insert(defOp);
return false;
diff --git a/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
new file mode 100644
index 0000000000000..ff2136df364d9
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
@@ -0,0 +1,38 @@
+//===- IRDLSymbols.cpp - IRDL-related symbol logic ------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Returns where dialect-absolute symbol lookups start: the parent of the
+/// IRDL dialect definition that is or encloses `source`. Returns null if that
+/// dialect definition is top-level, as no symbol table can then contain it.
+static Operation *getLookupRoot(Operation *source) {
+  auto dialectOp = dyn_cast<DialectOp>(source);
+  if (!dialectOp)
+    dialectOp = source->getParentOfType<DialectOp>();
+  assert(dialectOp && "symbol lookup near dialect must originate from "
+                      "within a dialect definition");
+  return dialectOp->getParentOp();
+}
+
+Operation *
+mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                    Operation *source, SymbolRefAttr symbol) {
+  Operation *root = getLookupRoot(source);
+  return root ? symbolTable.lookupNearestSymbolFrom(root, symbol) : nullptr;
+}
+
+Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
+                                               SymbolRefAttr symbol) {
+  Operation *root = getLookupRoot(source);
+  return root ? SymbolTable::lookupNearestSymbolFrom(root, symbol) : nullptr;
+}
diff --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
index 997af08d24733..0b7e220ceb90c 100644
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -19,13 +19,13 @@ module {
// CHECK: irdl.operation @norm {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
// CHECK: irdl.operands(%[[v1]])
// CHECK: irdl.results(%[[v0]])
// CHECK: }
irdl.operation @norm {
%0 = irdl.any
- %1 = irdl.parametric @complex<%0>
+ %1 = irdl.parametric @cmath::@complex<%0>
irdl.operands(%1)
irdl.results(%0)
}
@@ -34,7 +34,7 @@ module {
// CHECK: %[[v0:[^ ]*]] = irdl.is f32
// CHECK: %[[v1:[^ ]*]] = irdl.is f64
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
- // CHECK: %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
+ // CHECK: %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
// CHECK: irdl.operands(%[[v3]], %[[v3]])
// CHECK: irdl.results(%[[v3]])
// CHECK: }
@@ -42,7 +42,7 @@ module {
%0 = irdl.is f32
%1 = irdl.is f64
%2 = irdl.any_of(%0, %1)
- %3 = irdl.parametric @complex<%2>
+ %3 = irdl.parametric @cmath::@complex<%2>
irdl.operands(%3, %3)
irdl.results(%3)
}
diff --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
index db8dfc5cb36ca..cbcc248bf00b1 100644
--- a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
@@ -6,14 +6,14 @@
irdl.dialect @testd {
// CHECK: irdl.type @self_referencing {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
// CHECK: }
irdl.type @self_referencing {
%0 = irdl.any
- %1 = irdl.parametric @self_referencing<%0>
+ %1 = irdl.parametric @testd::@self_referencing<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
@@ -22,13 +22,13 @@ irdl.dialect @testd {
// CHECK: irdl.type @type1 {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
irdl.type @type1 {
%0 = irdl.any
- %1 = irdl.parametric @type2<%0>
+ %1 = irdl.parametric @testd::@type2<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
@@ -36,13 +36,13 @@ irdl.dialect @testd {
// CHECK: irdl.type @type2 {
// CHECK: %[[v0:[^ ]*]] = irdl.any
- // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
+ // CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
irdl.type @type2 {
%0 = irdl.any
- %1 = irdl.parametric @type1<%0>
+ %1 = irdl.parametric @testd::@type1<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
index f207d31cf158b..93ad619358750 100644
--- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -2,8 +2,6 @@
// Testing invalid IRDL IRs
-func.func private @foo()
-
irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
// -----
+func.func private @not_a_type_or_attr()
+
irdl.dialect @invalid_parametric {
irdl.operation @foo {
// expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
%param = irdl.parametric @not_a_type_or_attr<>
irdl.results(%param)
}
-
- irdl.operation @not_a_type_or_attr {
- %param = irdl.is i1
- irdl.results(%param)
- }
}
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index f828d95bdb81d..aeb1a83747ecc 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -76,20 +76,20 @@ irdl.dialect @testd {
}
// CHECK: irdl.operation @dyn_type_base {
- // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
+ // CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric
// CHECK: irdl.results(%[[v1]])
// CHECK: }
irdl.operation @dyn_type_base {
- %0 = irdl.base @parametric
+ %0 = irdl.base @testd::@parametric
irdl.results(%0)
}
// CHECK: irdl.operation @dyn_attr_base {
- // CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
+ // CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
// CHECK: }
irdl.operation @dyn_attr_base {
- %0 = irdl.base @parametric_attr
+ %0 = irdl.base @testd::@parametric_attr
irdl.attributes {"attr1" = %0}
}
@@ -115,14 +115,14 @@ irdl.dialect @testd {
// CHECK: %[[v0:[^ ]*]] = irdl.is i32
// CHECK: %[[v1:[^ ]*]] = irdl.is i64
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
- // CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
+ // CHECK: %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
// CHECK: irdl.results(%[[v3]])
// CHECK: }
irdl.operation @dynparams {
%0 = irdl.is i32
%1 = irdl.is i64
%2 = irdl.any_of(%0, %1)
- %3 = irdl.parametric @parametric<%2>
+ %3 = irdl.parametric @testd::@parametric<%2>
irdl.results(%3)
}
More information about the Mlir-commits
mailing list