[Mlir-commits] [mlir] [mlir][PDL] Add CallableOpInterface to pdl.pattern and inlining support to pdl (PR #172071)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 12 11:13:55 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-ods

Author: Fabian Mora (fabianmcg)

<details>
<summary>Changes</summary>

This commit enables inlining of calls within PDL patterns by:

1. Adding CallableOpInterface to PatternOp, and implementing the required
   interface methods (getCallableRegion, getArgumentTypes, getResultTypes)
   and the ArgAndResultAttrsOpInterface stubs to make pdl.pattern a
   valid callable.

2. Adding a PDL dialect inliner interface that marks all PDL operations and
   regions as legal to inline.

This is particularly useful for nonmaterializable patterns that may
contain func.call operations to external functions defining pattern
matching or rewrite logic. After inlining, these patterns can be
transformed into standard materializable PDL patterns.

NOTE: `pdl.pattern` must implement CallableOpInterface, because the inliner
does not inline calls that have no callable ancestor.

Example:
```mlir
func.func private @pattern_body() -> (!pdl.type, !pdl.type, !pdl.operation) {
  %0 = pdl.type : i32
  %1 = pdl.type
  %2 = pdl.operation  -> (%0, %1 : !pdl.type, !pdl.type)
  return %0, %1, %2 : !pdl.type, !pdl.type, !pdl.operation
}
func.func private @rewrite_body(%arg0: !pdl.type, %arg1: !pdl.type, %arg2: !pdl.operation) {
  %0 = pdl.operation "foo.op"  -> (%arg0, %arg1 : !pdl.type, !pdl.type)
  pdl.apply_native_rewrite "NativeRewrite"(%0, %arg2 : !pdl.operation, !pdl.operation)
  return
}
pdl.pattern @nonmaterializable_pattern : benefit(1) nonmaterializable {
  %0:3 = func.call @pattern_body() : () -> (!pdl.type, !pdl.type, !pdl.operation)
  rewrite %0#2 {
    func.call @rewrite_body(%0#0, %0#1, %0#2) : (!pdl.type, !pdl.type, !pdl.operation) -> ()
  }
}
// After running `mlir-opt --inline`:
pdl.pattern @nonmaterializable_pattern : benefit(1) nonmaterializable {
  %0 = type : i32
  %1 = type
  %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
  rewrite %2 {
    %3 = operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
    apply_native_rewrite "NativeRewrite"(%3, %2 : !pdl.operation, !pdl.operation)
  }
}
```

---
Full diff: https://github.com/llvm/llvm-project/pull/172071.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/PDL/IR/PDLOps.h (+1) 
- (modified) mlir/include/mlir/Dialect/PDL/IR/PDLOps.td (+83-11) 
- (modified) mlir/include/mlir/IR/OpBase.td (+8) 
- (modified) mlir/include/mlir/IR/OpDefinition.h (+21) 
- (modified) mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp (+10) 
- (modified) mlir/lib/Dialect/PDL/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/PDL/IR/PDL.cpp (+27-2) 
- (added) mlir/test/Conversion/PDLToPDLInterp/invalid.mlir (+14) 
- (added) mlir/test/Dialect/PDL/inlining.mlir (+30) 
- (modified) mlir/test/Dialect/PDL/ops.mlir (+26) 
- (modified) mlir/test/IR/traits.mlir (+42) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h
index 22935827519d3..c72bde94d3dd0 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 6ee638c19d1ad..ca22d0bee6bac 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -30,7 +31,7 @@ class PDL_Op<string mnemonic, list<Trait> traits = []>
 //===----------------------------------------------------------------------===//
 
 def PDL_ApplyNativeConstraintOp
-    : PDL_Op<"apply_native_constraint", [HasParent<"pdl::PatternOp">]> {
+    : PDL_Op<"apply_native_constraint", [HasParentNotOf<"pdl::RewriteOp">]> {
   let summary = "Apply a native constraint to a set of provided entities";
   let description = [{
     `pdl.apply_native_constraint` operations apply a native C++ constraint, that
@@ -62,7 +63,7 @@ def PDL_ApplyNativeConstraintOp
 //===----------------------------------------------------------------------===//
 
 def PDL_ApplyNativeRewriteOp
-    : PDL_Op<"apply_native_rewrite", [HasParent<"pdl::RewriteOp">]> {
+    : PDL_Op<"apply_native_rewrite", [HasParentNotOf<"pdl::PatternOp">]> {
   let summary = "Apply a native rewrite method inside of pdl.rewrite region";
   let description = [{
     `pdl.apply_native_rewrite` operations apply a native C++ function, that has
@@ -150,7 +151,7 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
 // pdl::EraseOp
 //===----------------------------------------------------------------------===//
 
-def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> {
+def PDL_EraseOp : PDL_Op<"erase", [HasParentNotOf<"pdl::PatternOp">]> {
   let summary = "Mark an input operation as `erased`";
   let description = [{
     `pdl.erase` operations are used within `pdl.rewrite` regions to specify that
@@ -172,7 +173,7 @@ def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> {
 //===----------------------------------------------------------------------===//
 
 def PDL_OperandOp
-    : PDL_Op<"operand", [HasParent<"pdl::PatternOp">]> {
+    : PDL_Op<"operand", [HasParentNotOf<"pdl::RewriteOp">]> {
   let summary = "Define an external input operand in a pattern";
   let description = [{
     `pdl.operand` operations capture external operand edges into an operation
@@ -211,7 +212,7 @@ def PDL_OperandOp
 //===----------------------------------------------------------------------===//
 
 def PDL_OperandsOp
-    : PDL_Op<"operands", [HasParent<"pdl::PatternOp">]> {
+    : PDL_Op<"operands", [HasParentNotOf<"pdl::RewriteOp">]> {
   let summary = "Define a range of input operands in a pattern";
   let description = [{
     `pdl.operands` operations capture external operand range edges into an
@@ -393,7 +394,8 @@ def PDL_OperationOp : PDL_Op<"operation", [AttrSizedOperandSegments]> {
 
 def PDL_PatternOp : PDL_Op<"pattern", [
     IsolatedFromAbove, SingleBlock, Symbol,
-    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getDefaultDialect"]>,
+    CallableOpInterface
   ]> {
   let summary = "Define a rewrite pattern";
   let description = [{
@@ -403,6 +405,11 @@ def PDL_PatternOp : PDL_Op<"pattern", [
     the pattern is specified within the region body, with the rewrite provided
     by a terminating `pdl.rewrite`.
 
+    The `nonmaterializable` attribute indicates that the pattern cannot be
+    materialized in a PDL backend. This is used for patterns that require
+    further transformations before they can be lowered, e.g. patterns that
+    contain non-PDL operations.
+
     Example:
 
     ```mlir
@@ -420,15 +427,19 @@ def PDL_PatternOp : PDL_Op<"pattern", [
   }];
 
   let arguments = (ins ConfinedAttr<I16Attr, [IntNonNegative]>:$benefit,
-                       OptionalAttr<SymbolNameAttr>:$sym_name);
+                       OptionalAttr<SymbolNameAttr>:$sym_name,
+                       UnitAttr:$nonmaterializable);
   let regions = (region SizedRegion<1>:$bodyRegion);
   let assemblyFormat = [{
-    ($sym_name^)? `:` `benefit` `(` $benefit `)` attr-dict-with-keyword $bodyRegion
+    ($sym_name^)? `:` `benefit` `(` $benefit `)`
+    (`nonmaterializable` $nonmaterializable^)?
+    attr-dict-with-keyword $bodyRegion
   }];
 
   let builders = [
     OpBuilder<(ins CArg<"std::optional<uint16_t>", "1">:$benefit,
-                   CArg<"std::optional<StringRef>", "std::nullopt">:$name)>,
+                   CArg<"std::optional<StringRef>", "std::nullopt">:$name,
+                   CArg<"bool", "false">:$nonmaterializable)>,
   ];
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
@@ -440,6 +451,67 @@ def PDL_PatternOp : PDL_Op<"pattern", [
 
     /// Returns the rewrite operation of this pattern.
     RewriteOp getRewriter();
+
+    //===------------------------------------------------------------------===//
+    // ArgAndResultAttrsOpInterface Methods
+    //===------------------------------------------------------------------===//
+    /// Get the argument attributes. Always null: patterns take no arguments.
+    ArrayAttr getArgAttrsAttr() {
+      return nullptr;
+    }
+
+    /// Get the result attributes. Always null: patterns return no results.
+    ArrayAttr getResAttrsAttr() {
+      return nullptr;
+    }
+
+    /// Set the argument attributes.
+    /// NOTE: PDL patterns do not support argument attributes; calling this
+    /// method will assert.
+    void setArgAttrsAttr(ArrayAttr attrs) {
+      (void)attrs;
+      assert(false && "PDL patterns do not support argument attributes.");
+    }
+
+    /// Set the result attributes.
+    /// NOTE: PDL patterns do not support result attributes; calling this
+    /// method will assert.
+    void setResAttrsAttr(ArrayAttr attrs) {
+      (void)attrs;
+      assert(false && "PDL patterns do not support result attributes.");
+    }
+
+    /// Remove the argument attributes.
+    /// NOTE: this method is a no-op for PDL patterns.
+    Attribute removeArgAttrsAttr() {
+      return nullptr;
+    }
+
+    /// Remove the result attributes.
+    /// NOTE: this method is a no-op for PDL patterns.
+    Attribute removeResAttrsAttr() {
+      return nullptr;
+    }
+
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface Methods
+    //===------------------------------------------------------------------===//
+    /// Get the callable region, i.e. the pattern body region.
+    Region *getCallableRegion() {
+      return &getBodyRegion();
+    }
+
+    /// Get the argument types.
+    ArrayRef<Type> getArgumentTypes() {
+      // Patterns take no arguments.
+      return {};
+    }
+
+    /// Get the result types.
+    ArrayRef<Type> getResultTypes() {
+      // Patterns return no results.
+      return {};
+    }
   }];
   let hasRegionVerifier = 1;
 }
@@ -448,7 +520,7 @@ def PDL_PatternOp : PDL_Op<"pattern", [
 // pdl::RangeOp
 //===----------------------------------------------------------------------===//
 
-def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> {
+def PDL_RangeOp : PDL_Op<"range", [Pure, HasParentNotOf<"pdl::PatternOp">]> {
   let summary = "Construct a range of pdl entities";
   let description = [{
     `pdl.range` operations construct a range from a given set of PDL entities,
@@ -491,7 +563,7 @@ def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> {
 //===----------------------------------------------------------------------===//
 
 def PDL_ReplaceOp : PDL_Op<"replace", [
-    AttrSizedOperandSegments, HasParent<"pdl::RewriteOp">
+    AttrSizedOperandSegments, HasParentNotOf<"pdl::PatternOp">
   ]> {
   let summary = "Mark an input operation as `replaced`";
   let description = [{
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 8d7dafae0ee76..5ff2b6d8a035c 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -141,6 +141,14 @@ class ParentOneOf<list<string> ops>
     : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
       StructuralOpTrait;
 
+// Op's parent operation is not the provided one.
+class HasParentNotOf<string op>
+    : ParamNativeOpTrait<"HasParentNotOf", op>, StructuralOpTrait;
+
+class ParentNotOneOf<list<string> ops>
+    : ParamNativeOpTrait<"HasParentNotOf", !interleave(ops, ", ")>,
+      StructuralOpTrait;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index be92fe0a6c7e3..78e25b13ebadf 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1323,6 +1323,27 @@ struct HasParent {
   };
 };
 
+/// This class provides a verifier for ops that expect their parent op to not
+/// be any of the given parent ops. Ops without a parent op are accepted.
+template <typename... ParentOpTypes>
+struct HasParentNotOf {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      // `isa_and_nonnull` is false for a null parent, so detached ops pass.
+      if (!llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
+        return success();
+
+      return op->emitOpError()
+             << "expects parent op "
+             << (sizeof...(ParentOpTypes) != 1 ? "to not be one of '"
+                                               : "to not be '")
+             << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
+    }
+  };
+};
+
 /// A trait for operations that have an attribute specifying operand segments.
 ///
 /// Certain operations can have multiple variadic operands and their size
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index a4c66e125f6bf..a1cb7ea9d1709 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -991,6 +991,16 @@ struct PDLToPDLInterpPass
 void PDLToPDLInterpPass::runOnOperation() {
   ModuleOp module = getOperation();
 
+  // Check there are no non-materializable patterns.
+  for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
+    if (pattern.getNonmaterializable()) {
+      pattern.emitError()
+          << "pdl_interp backend does not support non-materializable "
+             "patterns";
+      return signalPassFailure();
+    }
+  }
+
   // Create the main matcher function This function contains all of the match
   // related functionality from patterns in the module.
   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
diff --git a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt
index a0bec9f51a623..8c1a7f3197c58 100644
--- a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRPDLDialect
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRCallInterfaces
   MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index 8af93335ca96c..a817874836c70 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <optional>
 
@@ -23,12 +24,31 @@ using namespace mlir::pdl;
 // PDLDialect
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This class defines the interface for handling inlining for pdl
+/// dialect operations.
+struct PDLInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  // Everything can be inlined.
+  bool isLegalToInline(Operation *, Operation *, bool) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+};
+} // namespace
+
 void PDLDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
       >();
   registerTypes();
+  addInterfaces<PDLInlinerInterface>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -322,6 +342,10 @@ LogicalResult PatternOp::verifyRegions() {
         .append("see terminator defined here");
   }
 
+  // Skip if the pattern is marked non-materializable.
+  if (getNonmaterializable())
+    return success();
+
   // Check that all values defined in the top-level pattern belong to the PDL
   // dialect.
   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
@@ -385,9 +409,10 @@ LogicalResult PatternOp::verifyRegions() {
 
 void PatternOp::build(OpBuilder &builder, OperationState &state,
                       std::optional<uint16_t> benefit,
-                      std::optional<StringRef> name) {
+                      std::optional<StringRef> name, bool nonmaterializable) {
   build(builder, state, builder.getI16IntegerAttr(benefit.value_or(0)),
-        name ? builder.getStringAttr(*name) : StringAttr());
+        name ? builder.getStringAttr(*name) : StringAttr(),
+        nonmaterializable ? builder.getUnitAttr() : UnitAttr());
   state.regions[0]->emplaceBlock();
 }
 
diff --git a/mlir/test/Conversion/PDLToPDLInterp/invalid.mlir b/mlir/test/Conversion/PDLToPDLInterp/invalid.mlir
new file mode 100644
index 0000000000000..3ff79a679f5ed
--- /dev/null
+++ b/mlir/test/Conversion/PDLToPDLInterp/invalid.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -convert-pdl-to-pdl-interp --verify-diagnostics
+
+func.func private @pattern_body() -> (!pdl.type, !pdl.type, !pdl.operation)
+func.func private @rewrite_body(!pdl.type, !pdl.type, !pdl.operation)
+
+// expected-error@below {{pdl_interp backend does not support non-materializable patterns}}
+pdl.pattern @nonmaterializable_pattern : benefit(1) nonmaterializable {
+  %type1, %type2, %root = func.call @pattern_body()
+    : () -> (!pdl.type, !pdl.type, !pdl.operation)
+  rewrite %root {
+    func.call @rewrite_body(%type1, %type2, %root)
+      : (!pdl.type, !pdl.type, !pdl.operation) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/PDL/inlining.mlir b/mlir/test/Dialect/PDL/inlining.mlir
new file mode 100644
index 0000000000000..285bef85817d7
--- /dev/null
+++ b/mlir/test/Dialect/PDL/inlining.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -inline %s | FileCheck %s
+
+func.func private @pattern_body() -> (!pdl.type, !pdl.type, !pdl.operation) {
+  %0 = pdl.type : i32
+  %1 = pdl.type
+  %2 = pdl.operation  -> (%0, %1 : !pdl.type, !pdl.type)
+  return %0, %1, %2 : !pdl.type, !pdl.type, !pdl.operation
+}
+
+func.func private @rewrite_body(%arg0: !pdl.type, %arg1: !pdl.type, %arg2: !pdl.operation) {
+  %0 = pdl.operation "foo.op"  -> (%arg0, %arg1 : !pdl.type, !pdl.type)
+  pdl.apply_native_rewrite "NativeRewrite"(%0, %arg2 : !pdl.operation, !pdl.operation)
+  return
+}
+
+// CHECK-LABEL:   pdl.pattern @nonmaterializable_pattern : benefit(1) nonmaterializable {
+// CHECK:           %[[VAL_0:.*]] = type : i32
+// CHECK:           %[[VAL_1:.*]] = type
+// CHECK:           %[[VAL_2:.*]] = operation  -> (%[[VAL_0]], %[[VAL_1]] : !pdl.type, !pdl.type)
+// CHECK:           rewrite %[[VAL_2]] {
+// CHECK:             %[[VAL_3:.*]] = operation "foo.op"  -> (%[[VAL_0]], %[[VAL_1]] : !pdl.type, !pdl.type)
+// CHECK:             apply_native_rewrite "NativeRewrite"(%[[VAL_3]], %[[VAL_2]] : !pdl.operation, !pdl.operation)
+// CHECK:           }
+// CHECK:         }
+pdl.pattern @nonmaterializable_pattern : benefit(1) nonmaterializable {
+  %0:3 = func.call @pattern_body() : () -> (!pdl.type, !pdl.type, !pdl.operation)
+  rewrite %0#2 {
+    func.call @rewrite_body(%0#0, %0#1, %0#2) : (!pdl.type, !pdl.type, !pdl.operation) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 20e40deea5f86..4d579e264f43b 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -173,3 +173,29 @@ pdl.pattern @attribute_with_loc : benefit(1) {
   %root = operation {"attribute" = %attr}
   rewrite %root with "rewriter"
 }
+
+// -----
+
+// Check that non-materializable patterns allow non-PDL operations in their
+// bodies.
+func.func @pattern_body() -> (!pdl.type, !pdl.type, !pdl.operation) {
+  %type1 = pdl.type : i32
+  %type2 = pdl.type
+  %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
+  return %type1, %type2, %root : !pdl.type, !pdl.type, !pdl.operation
+}
+
+func.func @rewrite_body(%type1: !pdl.type, %type2: !pdl.type, %root: !pdl.operation) {
+  %newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type)
+  pdl.apply_native_rewrite "NativeRewrite"(%newOp, %root : !pdl.operation, !pdl.operation)
+  return
+}
+
+pdl.pattern @nonmaterializable_pattern : benefit(1) nonmaterializable {
+  %type1, %type2, %root = func.call @pattern_body()
+    : () -> (!pdl.type, !pdl.type, !pdl.operation)
+  rewrite %root {
+    func.call @rewrite_body(%type1, %type2, %root)
+      : (!pdl.type, !pdl.type, !pdl.operation) -> ()
+  }
+}
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 49cfd7e496746..1dc55940e0d3f 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -303,6 +303,48 @@ func.func @failedParentOneOf_wrong_parent1() {
    }) : () -> ()
 }
 
+// -----
+
+func.func @failedParentNotOf() {
+  "test.parent"() ({
+   // expected-error@+1 {{expects parent op to not be 'test.parent'}}
+    "test.not_child"() : () -> ()
+  }) : () -> ()
+}
+
+// -----
+
+// CHECK: succeededChildWithParentNotOf
+func.func @succeededChildWithParentNotOf() {
+  "test.not_child"() : () -> ()
+  return
+}
+
+// -----
+
+// CHECK: succeededChildWithParentNotOneOf
+func.func @succeededChildWithParentNotOneOf() {
+  "test.child_with_parent_not_one_of"() : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedParent1NotOneOf() {
+  "test.parent1"() ({
+   // expected-error@+1 {{expects parent op to not be one of 'test.parent, test.parent1'}}
+    "test.child_with_parent_not_one_of"() : () -> ()
+   }) : () -> ()
+}
+
+// -----
+
+func.func @failedParentNotOneOf() {
+  "test.parent"() ({
+   // expected-error@+1 {{expects parent op to not be one of 'test.parent, test.parent1'}}
+    "test.child_with_parent_not_one_of"() : () -> ()
+   }) : () -> ()
+}
 
 // -----
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5417ae94f00d7..b514a00873ff6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -861,6 +861,7 @@ def ParentOp : TEST_Op<"parent"> {
     let regions = (region AnyRegion);
 }
 def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
+def NotChildOp : TEST_Op<"not_child", [HasParentNotOf<"ParentOp">]>;
 
 // ParentOneOf trait
 def ParentOp1 : TEST_Op<"parent1"> {
@@ -868,6 +869,8 @@ def ParentOp1 : TEST_Op<"parent1"> {
 }
 def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
                                 [ParentOneOf<["ParentOp", "ParentOp1"]>]>;
+def ChildWithParentNotOneOf : TEST_Op<"child_with_parent_not_one_of",
+                                [ParentNotOneOf<["ParentOp", "ParentOp1"]>]>;
 
 def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",

``````````

</details>


https://github.com/llvm/llvm-project/pull/172071


More information about the Mlir-commits mailing list