[Mlir-commits] [mlir] [mlir][irdl] Lookup symbols near dialects instead of locally (PR #92819)

Théo Degioanni llvmlistbot at llvm.org
Mon May 20 14:27:02 PDT 2024


https://github.com/Moxinilian updated https://github.com/llvm/llvm-project/pull/92819

>From f6ccbe282d692bff7fe9e392e0e2ff5d9644fd4c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Degioanni?=
 <theo.degioanni.llvm.deluge062 at simplelogin.fr>
Date: Mon, 20 May 2024 21:55:44 +0100
Subject: [PATCH] lookup symbols near dialects instead of locally

---
 mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h  | 37 ++++++++++++++++++
 mlir/lib/Dialect/IRDL/CMakeLists.txt          |  1 +
 mlir/lib/Dialect/IRDL/IR/IRDL.cpp             |  7 +++-
 mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp          | 10 +++--
 mlir/lib/Dialect/IRDL/IRDLLoading.cpp         |  3 +-
 mlir/lib/Dialect/IRDL/IRDLSymbols.cpp         | 38 +++++++++++++++++++
 mlir/test/Dialect/IRDL/cmath.irdl.mlir        |  8 ++--
 mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir | 12 +++---
 mlir/test/Dialect/IRDL/invalid.irdl.mlir      |  9 +----
 mlir/test/Dialect/IRDL/testd.irdl.mlir        | 12 +++---
 10 files changed, 108 insertions(+), 29 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
 create mode 100644 mlir/lib/Dialect/IRDL/IRDLSymbols.cpp

diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
new file mode 100644
index 0000000000000..4b7292c054ec2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
@@ -0,0 +1,37 @@
+//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages lookup logic for IRDL dialect-absolute symbols.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Resolves `symbol` starting from the parent operation of the IRDL dialect
+/// definition that encloses `source`, using `symbolTable` for the lookup.
+/// `source` must be a dialect definition operation or be nested within one.
+/// Callers in this patch treat a null result as "symbol not found".
+Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                   Operation *source, SymbolRefAttr symbol);
+
+/// Same lookup without a SymbolTableCollection: resolves `symbol` starting
+/// from the parent of the IRDL dialect definition that encloses `source`.
+/// `source` must be a dialect definition operation or be nested within one.
+Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index d25760e5d29bc..db4b98ef5308e 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
   IR/IRDLOps.cpp
   IRDLLoading.cpp
+  IRDLSymbols.cpp
   IRDLVerifiers.cpp
 
   DEPENDS
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index e4728f55b49d7..1f5584fa30c27 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
   return success();
 }
 
+/// Checks whether the provided symbol is an IRDL type or attribute definition.
+/// The source operation must be within a DialectOp.
 static LogicalResult
 checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
                              Operation *source, SymbolRefAttr symbol) {
-  Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
+  Operation *targetOp =
+      irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
+
   if (!targetOp)
     return source->emitOpError() << "symbol '" << symbol << "' not found";
 
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index 0895306b8bce1..7ec3aa2741023 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/IR/ValueRange.h"
 #include <optional>
 
@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
   // Case where the input is a symbol reference.
   // This corresponds to the case where the base is an IRDL type or attribute.
   if (auto baseRef = getBaseRef()) {
+    // The verifier for BaseOp guarantees it is within a dialect.
     Operation *defOp =
-        SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+        irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
 
     // Type case.
     if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
   SmallVector<unsigned> constraints =
       getConstraintIndicesForArgs(getArgs(), valueToConstr);
 
-  // Symbol reference case for the base
+  // Symbol reference case for the base.
+  // The verifier for ParametricOp guarantees it is within a dialect.
   SymbolRefAttr symRef = getBaseType();
-  Operation *defOp =
-      SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+  Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
   if (!defOp) {
     emitError() << symRef << " does not refer to any existing symbol";
     return nullptr;
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 5df2b45d8037b..5f623e8845d10 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
   // For `irdl.parametric`, we get directly the base from the operation.
   if (auto params = dyn_cast<ParametricOp>(op)) {
     SymbolRefAttr symRef = params.getBaseType();
-    Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+    Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
     assert(defOp && "symbol reference should refer to an existing operation");
     paramIrdlOps.insert(defOp);
     return false;
diff --git a/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
new file mode 100644
index 0000000000000..ff2136df364d9
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
@@ -0,0 +1,38 @@
+//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Returns the closest DialectOp among `source` and its ancestors. Callers
+/// must guarantee one exists; otherwise this is unreachable.
+static Operation *lookupDialectOp(Operation *source) {
+  for (Operation *op = source; op; op = op->getParentOp())
+    if (isa<DialectOp>(op))
+      return op;
+
+  // No DialectOp on the parent chain: the caller broke the precondition.
+  llvm_unreachable("symbol lookup near dialect must originate from "
+                   "within a dialect definition");
+}
+
+Operation *
+mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                    Operation *source, SymbolRefAttr symbol) {
+  Operation *lookupRoot = lookupDialectOp(source)->getParentOp();
+  return symbolTable.lookupNearestSymbolFrom(lookupRoot, symbol);
+}
+
+Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
+                                               SymbolRefAttr symbol) {
+  Operation *lookupRoot = lookupDialectOp(source)->getParentOp();
+  return SymbolTable::lookupNearestSymbolFrom(lookupRoot, symbol);
+}
diff --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
index 997af08d24733..0b7e220ceb90c 100644
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -19,13 +19,13 @@ module {
 
     // CHECK: irdl.operation @norm {
     // CHECK:   %[[v0:[^ ]*]] = irdl.any
-    // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+    // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
     // CHECK:   irdl.operands(%[[v1]])
     // CHECK:   irdl.results(%[[v0]])
     // CHECK: }
     irdl.operation @norm {
       %0 = irdl.any
-      %1 = irdl.parametric @complex<%0>
+      %1 = irdl.parametric @cmath::@complex<%0>
       irdl.operands(%1)
       irdl.results(%0)
     }
@@ -34,7 +34,7 @@ module {
     // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
     // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
     // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
-    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
+    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
     // CHECK:   irdl.operands(%[[v3]], %[[v3]])
     // CHECK:   irdl.results(%[[v3]])
     // CHECK: }
@@ -42,7 +42,7 @@ module {
       %0 = irdl.is f32
       %1 = irdl.is f64
       %2 = irdl.any_of(%0, %1)
-      %3 = irdl.parametric @complex<%2>
+      %3 = irdl.parametric @cmath::@complex<%2>
       irdl.operands(%3, %3)
       irdl.results(%3)
     }
diff --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
index db8dfc5cb36ca..cbcc248bf00b1 100644
--- a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
@@ -6,14 +6,14 @@
 irdl.dialect @testd {
   // CHECK:   irdl.type @self_referencing {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   // CHECK: }
   irdl.type @self_referencing {
     %0 = irdl.any
-    %1 = irdl.parametric @self_referencing<%0>
+    %1 = irdl.parametric @testd::@self_referencing<%0>
     %2 = irdl.is i32
     %3 = irdl.any_of(%1, %2)
     irdl.parameters(%3)
@@ -22,13 +22,13 @@ irdl.dialect @testd {
 
   // CHECK:   irdl.type @type1 {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   irdl.type @type1 {
     %0 = irdl.any
-    %1 = irdl.parametric @type2<%0>
+    %1 = irdl.parametric @testd::@type2<%0>
     %2 = irdl.is i32
     %3 = irdl.any_of(%1, %2)
     irdl.parameters(%3)
@@ -36,13 +36,13 @@ irdl.dialect @testd {
 
   // CHECK:   irdl.type @type2 {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   irdl.type @type2 {
       %0 = irdl.any
-      %1 = irdl.parametric @type1<%0>
+      %1 = irdl.parametric @testd::@type1<%0>
       %2 = irdl.is i32
       %3 = irdl.any_of(%1, %2)
       irdl.parameters(%3)
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
index f207d31cf158b..93ad619358750 100644
--- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -2,8 +2,6 @@
 
 // Testing invalid IRDL IRs
 
-func.func private @foo()
-
 irdl.dialect @testd {
   irdl.type @type {
     // expected-error@+1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
 
 // -----
 
+func.func private @not_a_type_or_attr()
+
 irdl.dialect @invalid_parametric {
   irdl.operation @foo {
     // expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
     %param = irdl.parametric @not_a_type_or_attr<>
     irdl.results(%param)
   }
-
-  irdl.operation @not_a_type_or_attr {
-    %param = irdl.is i1
-    irdl.results(%param)
-  }
 }
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index f828d95bdb81d..aeb1a83747ecc 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -76,20 +76,20 @@ irdl.dialect @testd {
   }
 
   // CHECK: irdl.operation @dyn_type_base {
-  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @testd::@parametric
   // CHECK:   irdl.results(%[[v1]])
   // CHECK: }
   irdl.operation @dyn_type_base {
-    %0 = irdl.base @parametric
+    %0 = irdl.base @testd::@parametric
     irdl.results(%0)
   }
 
   // CHECK: irdl.operation @dyn_attr_base {
-  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric_attr
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
   // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
   // CHECK: }
   irdl.operation @dyn_attr_base {
-    %0 = irdl.base @parametric_attr
+    %0 = irdl.base @testd::@parametric_attr
     irdl.attributes {"attr1" = %0}
   }
 
@@ -115,14 +115,14 @@ irdl.dialect @testd {
   // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
   // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
-  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
+  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
   // CHECK:   irdl.results(%[[v3]])
   // CHECK: }
   irdl.operation @dynparams {
     %0 = irdl.is i32
     %1 = irdl.is i64
     %2 = irdl.any_of(%0, %1)
-    %3 = irdl.parametric @parametric<%2>
+    %3 = irdl.parametric @testd::@parametric<%2>
     irdl.results(%3)
   }
 



More information about the Mlir-commits mailing list