[Mlir-commits] [mlir] b86a9c5 - [mlir][irdl] Lookup symbols near dialects instead of locally (#92819)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 31 01:15:53 PDT 2024


Author: Théo Degioanni
Date: 2024-05-31T09:15:50+01:00
New Revision: b86a9c5bf2fab0408a3d549995d6e2449f71a16d

URL: https://github.com/llvm/llvm-project/commit/b86a9c5bf2fab0408a3d549995d6e2449f71a16d
DIFF: https://github.com/llvm/llvm-project/commit/b86a9c5bf2fab0408a3d549995d6e2449f71a16d.diff

LOG: [mlir][irdl] Lookup symbols near dialects instead of locally (#92819)

Because symbol references are only resolved within the nearest enclosing
symbol table, it was impossible to refer to operations outside of the
dialect currently being defined. This PR modifies the lookup logic to happen
relative to the symbol table containing the dialect-defining operations.
This is a bit of a hack but should unblock the situation here.

Added: 
    mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
    mlir/lib/Dialect/IRDL/IRDLSymbols.cpp

Modified: 
    mlir/lib/Dialect/IRDL/CMakeLists.txt
    mlir/lib/Dialect/IRDL/IR/IRDL.cpp
    mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
    mlir/lib/Dialect/IRDL/IRDLLoading.cpp
    mlir/test/Dialect/IRDL/cmath.irdl.mlir
    mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
    mlir/test/Dialect/IRDL/invalid.irdl.mlir
    mlir/test/Dialect/IRDL/testd.irdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
new file mode 100644
index 0000000000000..4b7292c054ec2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
@@ -0,0 +1,37 @@
+//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages lookup logic for IRDL dialect-absolute symbols.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Looks up `symbol` from the symbol table that contains the IRDL dialect
+/// definition enclosing `source`, which must be nested within (or be) such a
+/// dialect definition. Returns null if the symbol is not found. Symbol tables
+/// are cached in `symbolTable`, making repeated lookups cheaper.
+Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                   Operation *source, SymbolRefAttr symbol);
+
+/// Looks up `symbol` from the symbol table that contains the IRDL dialect
+/// definition enclosing `source`, which must be nested within (or be) such a
+/// dialect definition. Returns null if the symbol is not found.
+Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H

diff  --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index d25760e5d29bc..db4b98ef5308e 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
   IR/IRDLOps.cpp
   IRDLLoading.cpp
+  IRDLSymbols.cpp
   IRDLVerifiers.cpp
 
   DEPENDS

diff  --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index e4728f55b49d7..1f5584fa30c27 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
   return success();
 }
 
+/// Finds whether the provided symbol is an IRDL type or attribute definition.
+/// The source operation must be within a DialectOp.
 static LogicalResult
 checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
                              Operation *source, SymbolRefAttr symbol) {
-  Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
+  Operation *targetOp =
+      irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
+
   if (!targetOp)
     return source->emitOpError() << "symbol '" << symbol << "' not found";
 

diff  --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index 0895306b8bce1..7ec3aa2741023 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/IR/ValueRange.h"
 #include <optional>
 
@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
   // Case where the input is a symbol reference.
   // This corresponds to the case where the base is an IRDL type or attribute.
   if (auto baseRef = getBaseRef()) {
+    // The verifier for BaseOp guarantees it is within a dialect.
     Operation *defOp =
-        SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+        irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
 
     // Type case.
     if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
   SmallVector<unsigned> constraints =
       getConstraintIndicesForArgs(getArgs(), valueToConstr);
 
-  // Symbol reference case for the base
+  // Symbol reference case for the base.
+  // The verifier for ParametricOp guarantees it is within a dialect.
   SymbolRefAttr symRef = getBaseType();
-  Operation *defOp =
-      SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+  Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
   if (!defOp) {
     emitError() << symRef << " does not refer to any existing symbol";
     return nullptr;

diff  --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 5df2b45d8037b..5f623e8845d10 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
   // For `irdl.parametric`, we get directly the base from the operation.
   if (auto params = dyn_cast<ParametricOp>(op)) {
     SymbolRefAttr symRef = params.getBaseType();
-    Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+    Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
     assert(defOp && "symbol reference should refer to an existing operation");
     paramIrdlOps.insert(defOp);
     return false;

diff  --git a/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
new file mode 100644
index 0000000000000..ff2136df364d9
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
@@ -0,0 +1,50 @@
+//===- IRDLSymbols.cpp - IRDL-related symbol logic ------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Returns the closest IRDL dialect definition operation enclosing `source`,
+/// or `source` itself if it is one. Asserts in debug builds that such an
+/// operation exists; returns null otherwise.
+static Operation *lookupDialectOp(Operation *source) {
+  Operation *dialectOp = source;
+  while (dialectOp && !isa<DialectOp>(dialectOp))
+    dialectOp = dialectOp->getParentOp();
+
+  // Use an assertion rather than `llvm_unreachable`, which may be undefined
+  // behavior if reached in release builds; release builds instead fall back
+  // to a failed lookup.
+  assert(dialectOp && "symbol lookup near dialect must originate from "
+                      "within a dialect definition");
+  return dialectOp;
+}
+
+/// Returns the operation from which symbol lookups for `source` start: the
+/// parent of the enclosing dialect definition, whose symbol table contains
+/// the dialect. Returns null if there is no such operation.
+static Operation *lookupSymbolScope(Operation *source) {
+  Operation *dialectOp = lookupDialectOp(source);
+  return dialectOp ? dialectOp->getParentOp() : nullptr;
+}
+
+Operation *
+mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                    Operation *source, SymbolRefAttr symbol) {
+  Operation *scope = lookupSymbolScope(source);
+  return scope ? symbolTable.lookupNearestSymbolFrom(scope, symbol) : nullptr;
+}
+
+Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
+                                               SymbolRefAttr symbol) {
+  Operation *scope = lookupSymbolScope(source);
+  return scope ? SymbolTable::lookupNearestSymbolFrom(scope, symbol) : nullptr;
+}

diff  --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
index 997af08d24733..0b7e220ceb90c 100644
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -19,13 +19,13 @@ module {
 
     // CHECK: irdl.operation @norm {
     // CHECK:   %[[v0:[^ ]*]] = irdl.any
-    // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+    // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
     // CHECK:   irdl.operands(%[[v1]])
     // CHECK:   irdl.results(%[[v0]])
     // CHECK: }
     irdl.operation @norm {
       %0 = irdl.any
-      %1 = irdl.parametric @complex<%0>
+      %1 = irdl.parametric @cmath::@complex<%0>
       irdl.operands(%1)
       irdl.results(%0)
     }
@@ -34,7 +34,7 @@ module {
     // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
     // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
     // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
-    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
+    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
     // CHECK:   irdl.operands(%[[v3]], %[[v3]])
     // CHECK:   irdl.results(%[[v3]])
     // CHECK: }
@@ -42,7 +42,7 @@ module {
       %0 = irdl.is f32
       %1 = irdl.is f64
       %2 = irdl.any_of(%0, %1)
-      %3 = irdl.parametric @complex<%2>
+      %3 = irdl.parametric @cmath::@complex<%2>
       irdl.operands(%3, %3)
       irdl.results(%3)
     }

diff  --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
index db8dfc5cb36ca..cbcc248bf00b1 100644
--- a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
@@ -6,14 +6,14 @@
 irdl.dialect @testd {
   // CHECK:   irdl.type @self_referencing {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   // CHECK: }
   irdl.type @self_referencing {
     %0 = irdl.any
-    %1 = irdl.parametric @self_referencing<%0>
+    %1 = irdl.parametric @testd::@self_referencing<%0>
     %2 = irdl.is i32
     %3 = irdl.any_of(%1, %2)
     irdl.parameters(%3)
@@ -22,13 +22,13 @@ irdl.dialect @testd {
 
   // CHECK:   irdl.type @type1 {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   irdl.type @type1 {
     %0 = irdl.any
-    %1 = irdl.parametric @type2<%0>
+    %1 = irdl.parametric @testd::@type2<%0>
     %2 = irdl.is i32
     %3 = irdl.any_of(%1, %2)
     irdl.parameters(%3)
@@ -36,13 +36,13 @@ irdl.dialect @testd {
 
   // CHECK:   irdl.type @type2 {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   irdl.type @type2 {
       %0 = irdl.any
-      %1 = irdl.parametric @type1<%0>
+      %1 = irdl.parametric @testd::@type1<%0>
       %2 = irdl.is i32
       %3 = irdl.any_of(%1, %2)
       irdl.parameters(%3)

diff  --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
index f207d31cf158b..93ad619358750 100644
--- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -2,8 +2,6 @@
 
 // Testing invalid IRDL IRs
 
-func.func private @foo()
-
 irdl.dialect @testd {
   irdl.type @type {
     // expected-error at +1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
 
 // -----
 
+func.func private @not_a_type_or_attr()
+
 irdl.dialect @invalid_parametric {
   irdl.operation @foo {
     // expected-error at +1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
     %param = irdl.parametric @not_a_type_or_attr<>
     irdl.results(%param)
   }
-
-  irdl.operation @not_a_type_or_attr {
-    %param = irdl.is i1
-    irdl.results(%param)
-  }
 }

diff  --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index f828d95bdb81d..aeb1a83747ecc 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -76,20 +76,20 @@ irdl.dialect @testd {
   }
 
   // CHECK: irdl.operation @dyn_type_base {
-  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @testd::@parametric
   // CHECK:   irdl.results(%[[v1]])
   // CHECK: }
   irdl.operation @dyn_type_base {
-    %0 = irdl.base @parametric
+    %0 = irdl.base @testd::@parametric
     irdl.results(%0)
   }
 
   // CHECK: irdl.operation @dyn_attr_base {
-  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric_attr
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
   // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
   // CHECK: }
   irdl.operation @dyn_attr_base {
-    %0 = irdl.base @parametric_attr
+    %0 = irdl.base @testd::@parametric_attr
     irdl.attributes {"attr1" = %0}
   }
 
@@ -115,14 +115,14 @@ irdl.dialect @testd {
   // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
   // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
-  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
+  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
   // CHECK:   irdl.results(%[[v3]])
   // CHECK: }
   irdl.operation @dynparams {
     %0 = irdl.is i32
     %1 = irdl.is i64
     %2 = irdl.any_of(%0, %1)
-    %3 = irdl.parametric @parametric<%2>
+    %3 = irdl.parametric @testd::@parametric<%2>
     irdl.results(%3)
   }
 


        


More information about the Mlir-commits mailing list