[Mlir-commits] [mlir] [mlir] use irdl as matcher description in transform (PR #89779)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Apr 24 01:36:30 PDT 2024


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/89779

>From f373d5c6e7b3fd3df1266ac0759e31a1190d0ce0 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Tue, 23 Apr 2024 15:19:09 +0000
Subject: [PATCH] [mlir] use irdl as matcher description in transform

Introduce a new Transform dialect extension that uses IRDL op
definitions as matcher descriptors. IRDL essentially allows one to
define additional op constraints to be verified and, unlike PDL, does not
assume rewriting will happen. Leverage IRDL verification capability to
select ops that match an IRDL definition without actually
registering the corresponding operation with the system.
---
 .../include/mlir/Dialect/IRDL/IRDLVerifiers.h | 11 +++
 .../mlir/Dialect/Transform/CMakeLists.txt     |  1 +
 .../Transform/IRDLExtension/CMakeLists.txt    |  6 ++
 .../Transform/IRDLExtension/IRDLExtension.h   | 21 +++++
 .../IRDLExtension/IRDLExtensionOps.h          | 20 +++++
 .../IRDLExtension/IRDLExtensionOps.td         | 36 ++++++++
 .../Transform/PDLExtension/PDLExtension.h     |  5 ++
 mlir/include/mlir/InitAllExtensions.h         |  2 +
 mlir/lib/Dialect/IRDL/IRDLLoading.cpp         | 71 ++++++++++------
 mlir/lib/Dialect/Transform/CMakeLists.txt     |  1 +
 .../Transform/IRDLExtension/CMakeLists.txt    | 12 +++
 .../Transform/IRDLExtension/IRDLExtension.cpp | 34 ++++++++
 .../IRDLExtension/IRDLExtensionOps.cpp        | 84 +++++++++++++++++++
 mlir/test/Dialect/Transform/irdl.mlir         | 25 ++++++
 14 files changed, 302 insertions(+), 27 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h
 create mode 100644 mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h
 create mode 100644 mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td
 create mode 100644 mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
 create mode 100644 mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp
 create mode 100644 mlir/test/Dialect/Transform/irdl.mlir

diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
index 9ecb7c0107d7f8..89e99a63a5f104 100644
--- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -30,7 +30,10 @@ class DynamicTypeDefinition;
 namespace mlir {
 namespace irdl {
 
+class AttributeOp;
 class Constraint;
+class OperationOp;
+class TypeOp;
 
 /// Provides context to the verification of constraints.
 /// It contains the assignment of variables to attributes, and the assignment
@@ -246,6 +249,14 @@ struct RegionConstraint {
   std::optional<SmallVector<unsigned>> argumentConstraints;
   std::optional<size_t> blockCount;
 };
+
+/// Creates a verifier for the IRDL `operation`; returns null on failure.
+llvm::unique_function<LogicalResult(Operation *) const> createVerifier(
+    OperationOp operation,
+    const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>
+        &typeDefs,
+    const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+        &attrDefs);
 } // namespace irdl
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index 0cd71ec6919d9e..b6155b5f573f1b 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(DebugExtension)
 add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
new file mode 100644
index 00000000000000..dfcd906b43af04
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS IRDLExtensionOps.td)
+mlir_tablegen(IRDLExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(IRDLExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)
+
+add_mlir_doc(IRDLExtensionOps IRDLExtensionOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h
new file mode 100644
index 00000000000000..19684e1ed44468
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h
@@ -0,0 +1,21 @@
+//===- IRDLExtension.h - IRDL extension for Transform dialect ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the IRDL extension of the Transform dialect in the given registry.
+void registerIRDLExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h
new file mode 100644
index 00000000000000..7e1d5cad1fbd88
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h
@@ -0,0 +1,20 @@
+//===- IRDLExtensionOps.h - IRDL Transform dialect extension ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h.inc"
+
+#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
diff --git a/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td
new file mode 100644
index 00000000000000..6ca624aeda12c7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.td
@@ -0,0 +1,53 @@
+//===- IRDLExtensionOps.td - Transform dialect extension ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+def IRDLCollectMatchingOp : TransformDialectOp<"irdl.collect_matching",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     SymbolTable,
+     NoTerminator]> {
+  let summary =
+    "Finds ops that match the IRDL definition without registering them.";
+  let description = [{
+    Walks each payload operation associated with the `root` handle, visiting
+    the operation itself and every operation nested in it, and collects the
+    ones that satisfy the constraints of the IRDL operation definition held
+    in the body. The definition is only used to build a verifier: the
+    corresponding operation is never registered with the context.
+
+    The body must contain a single `irdl.dialect` operation that defines
+    exactly one `irdl.operation` and no IRDL types or attributes.
+
+    Diagnostics emitted while checking candidate operations are suppressed.
+    Payload operations nested in one another are visited more than once, so
+    `matched` may list the same operation several times.
+
+    This operation only reads the `root` handle and the payload IR, and
+    associates the `matched` handle with the collected operations.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$root);
+  let regions = (region SizedRegion<1>:$body);
+  let results = (outs TransformHandleTypeInterface:$matched);
+
+  let assemblyFormat =
+    "`in` $root `:` functional-type(operands, results) attr-dict-with-keyword "
+    "regions";
+
+  let hasVerifier = 1;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
index 08915213cd22c5..bf5a105bc9f29b 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
@@ -6,6 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
+
 namespace mlir {
 class DialectRegistry;
 
@@ -14,3 +17,5 @@ namespace transform {
 void registerPDLExtension(DialectRegistry &dialectRegistry);
 } // namespace transform
 } // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7708ca5571de3b..20a4ab6f18a286 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
 #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
@@ -77,6 +78,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   sparse_tensor::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
   transform::registerDebugExtension(registry);
+  transform::registerIRDLExtension(registry);
   transform::registerLoopExtension(registry);
   transform::registerPDLExtension(registry);
   vector::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index cfc8d092c8178a..5df2b45d8037b3 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -270,26 +270,30 @@ static LogicalResult irdlRegionVerifier(
   return success();
 }
 
-/// Define and load an operation represented by a `irdl.operation`
-/// operation.
-static WalkResult loadOperation(
-    OperationOp op, ExtensibleDialect *dialect,
-    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
-    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+llvm::unique_function<LogicalResult(Operation *) const>
+mlir::irdl::createVerifier(
+    OperationOp op,
+    const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+        &attrs) {
   // Resolve SSA values to verifier constraint slots
   SmallVector<Value> constrToValue;
   SmallVector<Value> regionToValue;
   for (Operation &op : op->getRegion(0).getOps()) {
     if (isa<VerifyConstraintInterface>(op)) {
-      if (op.getNumResults() != 1)
-        return op.emitError()
-               << "IRDL constraint operations must have exactly one result";
+      if (op.getNumResults() != 1) {
+        op.emitError()
+            << "IRDL constraint operations must have exactly one result";
+        return nullptr;
+      }
       constrToValue.push_back(op.getResult(0));
     }
     if (isa<VerifyRegionInterface>(op)) {
-      if (op.getNumResults() != 1)
-        return op.emitError()
-               << "IRDL constraint operations must have exactly one result";
+      if (op.getNumResults() != 1) {
+        op.emitError()
+            << "IRDL constraint operations must have exactly one result";
+        return nullptr;
+      }
       regionToValue.push_back(op.getResult(0));
     }
   }
@@ -302,7 +306,7 @@ static WalkResult loadOperation(
     std::unique_ptr<Constraint> verifier =
         op.getVerifier(constrToValue, types, attrs);
     if (!verifier)
-      return WalkResult::interrupt();
+      return nullptr;
     constraints.push_back(std::move(verifier));
   }
 
@@ -358,7 +362,7 @@ static WalkResult loadOperation(
   }
 
   // Gather which constraint slots correspond to attributes constraints
-  DenseMap<StringAttr, size_t> attributesContraints;
+  DenseMap<StringAttr, size_t> attributeConstraints;
   auto attributesOp = op.getOp<AttributesOp>();
   if (attributesOp.has_value()) {
     const Operation::operand_range values = attributesOp->getAttributeValues();
@@ -367,40 +371,53 @@ static WalkResult loadOperation(
     for (const auto &[name, value] : llvm::zip(names, values)) {
       for (auto [i, constr] : enumerate(constrToValue)) {
         if (constr == value) {
-          attributesContraints[cast<StringAttr>(name)] = i;
+          attributeConstraints[cast<StringAttr>(name)] = i;
           break;
         }
       }
     }
   }
 
-  // IRDL does not support defining custom parsers or printers.
-  auto parser = [](OpAsmParser &parser, OperationState &result) {
-    return failure();
-  };
-  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
-    printer.printGenericOp(op);
-  };
-
-  auto verifier =
+  return
       [constraints{std::move(constraints)},
        regionConstraints{std::move(regionConstraints)},
        operandConstraints{std::move(operandConstraints)},
        operandVariadicity{std::move(operandVariadicity)},
        resultConstraints{std::move(resultConstraints)},
        resultVariadicity{std::move(resultVariadicity)},
-       attributesContraints{std::move(attributesContraints)}](Operation *op) {
+       attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
         ConstraintVerifier verifier(constraints);
         const LogicalResult opVerifierResult = irdlOpVerifier(
             op, verifier, operandConstraints, operandVariadicity,
-            resultConstraints, resultVariadicity, attributesContraints);
+            resultConstraints, resultVariadicity, attributeConstraints);
         const LogicalResult opRegionVerifierResult =
             irdlRegionVerifier(op, verifier, regionConstraints);
         return LogicalResult::success(opVerifierResult.succeeded() &&
                                       opRegionVerifierResult.succeeded());
       };
+}
+
+/// Define and load an operation represented by a `irdl.operation`
+/// operation.
+static WalkResult loadOperation(
+    OperationOp op, ExtensibleDialect *dialect,
+    const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+        &attrs) {
+
+  // IRDL does not support defining custom parsers or printers.
+  auto parser = [](OpAsmParser &parser, OperationState &result) {
+    return failure();
+  };
+  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
+    printer.printGenericOp(op);
+  };
+
+  auto verifier = createVerifier(op, types, attrs);
+  if (!verifier)
+    return WalkResult::interrupt();
 
-  // IRDL supports only checking number of blocks and argument contraints
+  // IRDL supports only checking number of blocks and argument constraints
   // It is done in the main verifier to reuse `ConstraintVerifier` context
   auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
 
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 64115dcc29d639..0c0d5ebe0c212e 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(DebugExtension)
 add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(IRDLExtension)
 add_subdirectory(LoopExtension)
 add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt
new file mode 100644
index 00000000000000..9216a3d722021f
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRTransformDialectIRDLExtension
+  IRDLExtension.cpp
+  IRDLExtensionOps.cpp
+
+  DEPENDS
+  MLIRTransformDialectIRDLExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTransformDialect
+  MLIRIRDL
+)
diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
new file mode 100644
index 00000000000000..94004365b8a1a5
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
@@ -0,0 +1,38 @@
+//===- IRDLExtension.cpp - IRDL extension for the Transform dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+namespace {
+/// Transform dialect extension that provides the ops defined in
+/// IRDLExtensionOps.td.
+struct IRDLExtension
+    : public transform::TransformDialectExtension<IRDLExtension> {
+  void init() {
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
+        >();
+
+    // The bodies of these ops hold IRDL definitions (an `irdl.dialect`), so
+    // the IRDL dialect is declared as a dependent dialect of this extension.
+    declareDependentDialect<irdl::IRDLDialect>();
+  }
+};
+} // namespace
+
+/// Adds the IRDL extension to `dialectRegistry`; see IRDLExtension.h.
+void mlir::transform::registerIRDLExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<IRDLExtension>();
+}
diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp
new file mode 100644
index 00000000000000..9cc579e65edf91
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp
@@ -0,0 +1,102 @@
+//===- IRDLExtensionOps.cpp - IRDL extension for the Transform dialect ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
+
+namespace mlir::transform {
+
+/// Collects every op nested in (or equal to) a payload op of `root` that
+/// satisfies the constraints of the single `irdl.operation` in the body.
+/// Relies on verify() having checked the shape of the body.
+DiagnosedSilenceableFailure
+IRDLCollectMatchingOp::apply(TransformRewriter &rewriter,
+                             TransformResults &results, TransformState &state) {
+  auto dialect = cast<irdl::DialectOp>(getBody().front().front());
+  Block &body = dialect.getBody().front();
+  irdl::OperationOp operation = *body.getOps<irdl::OperationOp>().begin();
+  // verify() rejects IRDL types and attributes, so there are no dynamic type
+  // or attribute definitions to resolve references against.
+  auto verifier = irdl::createVerifier(
+      operation,
+      DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>(),
+      DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>());
+  // createVerifier returns a null function on failure; it must not be called.
+  if (!verifier)
+    return emitDefiniteFailure(
+        "failed to create a verifier from the IRDL operation definition");
+
+  SmallVector<Operation *> matched;
+  {
+    // Ops that do not match are the common case here, so silence whatever
+    // diagnostics the verifier reports. The handler is removed at scope exit.
+    ScopedDiagnosticHandler silenceDiagnostics(
+        getContext(), [](Diagnostic &) { return success(); });
+    // NOTE: payload ops nested in one another are walked more than once, so
+    // `matched` may contain duplicates.
+    for (Operation *payload : state.getPayloadOps(getRoot())) {
+      payload->walk([&](Operation *target) {
+        if (succeeded(verifier(target)))
+          matched.push_back(target);
+      });
+    }
+  }
+  results.set(cast<OpResult>(getMatched()), matched);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Declares that `root` is only read, `matched` is a newly produced handle,
+/// and the payload IR is read but never modified.
+void IRDLCollectMatchingOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getRoot(), effects);
+  producesHandle(getMatched(), effects);
+  onlyReadsPayload(effects);
+}
+
+/// Checks that the body holds a single `irdl.dialect` that defines exactly
+/// one `irdl.operation` and no IRDL types or attributes, which is the only
+/// form apply() supports.
+LogicalResult IRDLCollectMatchingOp::verify() {
+  Block &bodyBlock = getBody().front();
+  if (!llvm::hasSingleElement(bodyBlock))
+    return emitOpError() << "expects a single operation in the body";
+
+  auto dialect = dyn_cast<irdl::DialectOp>(bodyBlock.front());
+  if (!dialect) {
+    return emitOpError() << "expects the body operation to be "
+                         << irdl::DialectOp::getOperationName();
+  }
+
+  // TODO: relax this by taking a symbol name of the operation to match, note
+  // that symbol name is also the name of the operation and we may want to
+  // diverge from that to have constraints on-the-fly using IRDL.
+  auto irdlOperations = dialect.getOps<irdl::OperationOp>();
+  if (!llvm::hasSingleElement(irdlOperations))
+    return emitOpError() << "expects IRDL to contain exactly one operation";
+
+  if (!dialect.getOps<irdl::TypeOp>().empty() ||
+      !dialect.getOps<irdl::AttributeOp>().empty()) {
+    return emitOpError() << "IRDL types and attributes are not yet supported";
+  }
+
+  return success();
+}
+
+} // namespace mlir::transform
diff --git a/mlir/test/Dialect/Transform/irdl.mlir b/mlir/test/Dialect/Transform/irdl.mlir
new file mode 100644
index 00000000000000..d3faea0dffcc26
--- /dev/null
+++ b/mlir/test/Dialect/Transform/irdl.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    %0 = transform.irdl.collect_matching in %arg0 : (!transform.any_op) -> (!transform.any_op){
+    ^bb0(%arg1: !transform.any_op):
+      irdl.dialect @test {
+        irdl.operation @whatever {
+          %0 = irdl.is i32
+          %1 = irdl.is i64
+          %2 = irdl.any_of(%0, %1)
+          irdl.results(%2)
+        }
+      }
+    }
+    transform.debug.emit_remark_at %0, "matched" : !transform.any_op
+    transform.yield
+  }
+
+  // Only ops with a single i32 or i64 result satisfy the IRDL definition.
+  // expected-remark @below {{matched}}
+  "test.whatever"() : () -> i32
+  "test.whatever"() : () -> f32
+  // expected-remark @below {{matched}}
+  "test.whatever"() : () -> i64
+}



More information about the Mlir-commits mailing list