[Mlir-commits] [mlir] 1eb6994 - [mlir][PDL] Add support for variadic operands and results in PDL

River Riddle llvmlistbot at llvm.org
Tue Mar 16 13:20:32 PDT 2021


Author: River Riddle
Date: 2021-03-16T13:20:18-07:00
New Revision: 1eb6994d6ab18d5f6555acf515d27e2076fbea8a

URL: https://github.com/llvm/llvm-project/commit/1eb6994d6ab18d5f6555acf515d27e2076fbea8a
DIFF: https://github.com/llvm/llvm-project/commit/1eb6994d6ab18d5f6555acf515d27e2076fbea8a.diff

LOG: [mlir][PDL] Add support for variadic operands and results in PDL

This revision extends the PDL dialect to add support for variadic operands and results, with ranges of these values represented via the recently added !pdl.range type. To support this extension, three new operations have been added that closely match their single-value counterparts:
* pdl.operands : Define a range of input operands.
* pdl.results : Extract a result group from an operation.
* pdl.types : Define a handle to a range of types.

Support for these in the pdl interpreter dialect and byte code will be added in followup revisions.

Differential Revision: https://reviews.llvm.org/D95721

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
    mlir/test/Dialect/PDL/invalid.mlir
    mlir/test/Dialect/PDL/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 74f3fce08933..32de9f438c00 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -60,7 +60,7 @@ def PDL_ApplyNativeConstraintOp
 
   let builders = [
     OpBuilder<(ins "StringRef":$name, CArg<"ValueRange", "{}">:$args,
-      CArg<"ArrayRef<Attribute>", "{}">:$params), [{
+                   CArg<"ArrayRef<Attribute>", "{}">:$params), [{
       build($_builder, $_state, $_builder.getStringAttr(name), args,
             params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params));
     }]>,
@@ -196,9 +196,9 @@ def PDL_OperandOp : PDL_Op<"operand", [HasParent<"pdl::PatternOp">]> {
   let description = [{
     `pdl.operand` operations capture external operand edges into an operation
     node that originate from operations or block arguments not otherwise
-    specified within the pattern (e.g. via `pdl.result`). These operations
-    define individual operands of a given operation. A `pdl.operand` may
-    partially constrain an operand by specifying an expected value type
+    specified within the pattern (i.e. via `pdl.result` or `pdl.results`). These
+    operations define individual operands of a given operation. A `pdl.operand`
+    may partially constrain an operand by specifying an expected value type
     (via a `pdl.type` operation).
 
     Example:
@@ -224,6 +224,44 @@ def PDL_OperandOp : PDL_Op<"operand", [HasParent<"pdl::PatternOp">]> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::OperandsOp
+//===----------------------------------------------------------------------===//
+
+def PDL_OperandsOp : PDL_Op<"operands", [HasParent<"pdl::PatternOp">]> {
+  let summary = "Define a range of input operands in a pattern";
+  let description = [{
+    `pdl.operands` operations capture external operand range edges into an
+    operation node that originate from operations or block arguments not
+    otherwise specified within the pattern (i.e. via `pdl.result` or
+    `pdl.results`). These operations define groups of input operands into a
+    given operation. A `pdl.operands` may partially constrain a set of input
+    operands by specifying expected value types (via `pdl.types` operations).
+
+    Example:
+
+    ```mlir
+    // Define a range of input operands:
+    %operands = pdl.operands
+
+    // Define a range of input operands with expected types:
+    %types = pdl.types : [i32, i64, i32]
+    %typed_operands = pdl.operands : %types
+    ```
+  }];
+
+  let arguments = (ins Optional<PDL_RangeOf<PDL_Type>>:$type);
+  let results = (outs PDL_RangeOf<PDL_Value>:$val);
+  let assemblyFormat = "(`:` $type^)? attr-dict";
+
+  let builders = [
+    OpBuilder<(ins), [{
+      build($_builder, $_state, RangeType::get($_builder.getType<ValueType>()),
+            Value());
+    }]>,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::OperationOp
 //===----------------------------------------------------------------------===//
@@ -245,6 +283,14 @@ def PDL_OperationOp
     a handle to the operation itself. Handles to the results of the operation
     can be extracted via `pdl.result`.
 
+    Example:
+
+    ```mlir
+    // Define an instance of a `foo.op` operation.
+    %op = pdl.operation "foo.op"(%arg0, %arg1 : !pdl.value, !pdl.value)
+      {"attrA" = %attr0} -> (%type, %type : !pdl.type, !pdl.type)
+    ```
+
     When used within a matching context, the name of the operation may be
     omitted.
 
@@ -257,24 +303,78 @@ def PDL_OperationOp
     override the `InferTypeOpInterface` to ensure that the result types can be
     inferred.
 
-    Example:
+    The operands of the operation are interpreted in the following ways:
+
+    1) A single !pdl.range<value>:
+
+    In this case, the single range is treated as all of the operands of the
+    operation.
 
     ```mlir
-    // Define an instance of a `foo.op` operation.
-    %op = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type
+    // Define an instance with single range of operands.
+    %op = pdl.operation "std.return"(%allArgs : !pdl.range<value>)
+    ```
+
+    2) A variadic number of either !pdl.value or !pdl.range<value>:
+
+    In this case, the inputs are expected to correspond with the operand groups
+    defined on the operation in ODS.
+
+    ```tablegen
+    // Given the following operation definition in ODS:
+    def MyIndirectCallOp {
+      let arguments = (ins FunctionType:$call, Variadic<AnyType>:$args);
+    }
+    ```
+
+    ```mlir
+    // We can match the operands as follows:
+    %op = pdl.operation "my.indirect_call"(%call, %args : !pdl.value, !pdl.range<value>)
+    ```
+
+    The results of the operation are interpreted in the following ways:
+
+    1) A single !pdl.range<type>:
+
+    In this case, the single range is treated as all of the result types of the
+    operation.
+
+    ```mlir
+    // Define an instance with single range of types.
+    %allResultTypes = pdl.types
+    %op = pdl.operation "unrealized_conversion_cast" -> (%allResultTypes : !pdl.range<type>)
+    ```
+
+    2) A variadic number of either !pdl.type or !pdl.range<type>:
+
+    In this case, the types are expected to correspond with the result groups
+    defined on the operation in ODS.
+
+    ```tablegen
+    // Given the following operation definition in ODS:
+    def MyOp {
+      let results = (outs SomeType:$result, Variadic<SomeType>:$otherResults);
+    }
+    ```
+
+    ```mlir
+    // We can match the results as follows:
+    %result = pdl.type
+    %otherResults = pdl.types
+    %op = pdl.operation "foo.op" -> (%result, %otherResults : !pdl.type, !pdl.range<type>)
     ```
   }];
 
   let arguments = (ins OptionalAttr<StrAttr>:$name,
-                       Variadic<PDL_Value>:$operands,
+                       Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands,
                        Variadic<PDL_Attribute>:$attributes,
                        StrArrayAttr:$attributeNames,
-                       Variadic<PDL_Type>:$types);
+                       Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types);
   let results = (outs PDL_Operation:$op);
   let assemblyFormat = [{
-    ($name^)? (`(` $operands^ `)`)?
+    ($name^)? (`(` $operands^ `:` type($operands) `)`)?
     custom<OperationOpAttributes>($attributes, $attributeNames)
-    (`->` $types^)? attr-dict
+    (`->` `(` $types^ `:` type($types) `)`)? attr-dict
   }];
 
   let builders = [
@@ -378,7 +478,10 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
 
     ```mlir
     // Replace root node with 2 values:
-    pdl.replace %root with (%val0, %val1)
+    pdl.replace %root with (%val0, %val1 : !pdl.value, !pdl.value)
+
+    // Replace root node with a range of values:
+    pdl.replace %root with (%vals : !pdl.range<value>)
 
     // Replace root with another operation:
     pdl.replace %root with %otherOp
@@ -386,9 +489,10 @@ def PDL_ReplaceOp : PDL_Op<"replace", [
   }];
   let arguments = (ins PDL_Operation:$operation,
                        Optional<PDL_Operation>:$replOperation,
-                       Variadic<PDL_Value>:$replValues);
+                       Variadic<PDL_InstOrRangeOf<PDL_Value>>:$replValues);
   let assemblyFormat = [{
-    $operation `with` (`(` $replValues^ `)`)? ($replOperation^)? attr-dict
+    $operation `with` (`(` $replValues^ `:` type($replValues) `)`)?
+    ($replOperation^)? attr-dict
   }];
 }
 
@@ -409,13 +513,13 @@ def PDL_ResultOp : PDL_Op<"result"> {
     ```mlir
     // Extract a result:
     %operation = pdl.operation ...
-    %result = pdl.result 1 of %operation
+    %pdl_result = pdl.result 1 of %operation
 
     // Imagine the following IR being matched:
     %result_0, %result_1 = foo.op ...
 
     // If the example pattern snippet above were matching against `foo.op` in
-    // the IR snippted, `%result` would correspond to `%result_1`.
+    // the IR snippet, `%pdl_result` would correspond to `%result_1`.
     ```
   }];
 
@@ -425,6 +529,48 @@ def PDL_ResultOp : PDL_Op<"result"> {
   let verifier = ?;
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::ResultsOp
+//===----------------------------------------------------------------------===//
+
+def PDL_ResultsOp : PDL_Op<"results"> {
+  let summary = "Extract a result group from an operation";
+  let description = [{
+    `pdl.results` operations extract a result group from an operation within a
+    pattern or rewrite region. If an index is provided, this operation extracts
+    a result group as defined by the ODS definition of the operation. In this
+    case the result of this operation may be either a single `pdl.value` or
+    a `pdl.range<value>`, depending on the constraint of the result in ODS. If
+    no index is provided, this operation extracts the full result range of the
+    operation.
+
+    Example:
+
+    ```mlir
+    // Extract all of the results of an operation:
+    %operation = pdl.operation ...
+    %results = pdl.results of %operation
+
+    // Extract the results in the first result group of an operation, which is
+    // variadic:
+    %operation = pdl.operation ...
+    %results = pdl.results 0 of %operation -> !pdl.range<value>
+
+    // Extract the results in the second result group of an operation, which is
+    // not variadic:
+    %operation = pdl.operation ...
+    %results = pdl.results 1 of %operation -> !pdl.value
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$parent, OptionalAttr<I32Attr>:$index);
+  let results = (outs PDL_InstOrRangeOf<PDL_Value>:$val);
+  let assemblyFormat = [{
+    ($index^)? `of` $parent custom<ResultsValueType>(ref($index), type($val))
+    attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::RewriteOp
 //===----------------------------------------------------------------------===//
@@ -489,7 +635,7 @@ def PDL_RewriteEndOp : PDL_Op<"rewrite_end", [Terminator,
 def PDL_TypeOp : PDL_Op<"type"> {
   let summary = "Define a type handle within a pattern";
   let description = [{
-    `pdl.type` operations capture result type constraints of an `Attributes`,
+    `pdl.type` operations capture result type constraints of `Attributes`,
     `Values`, and `Operations`. Instances of this operation define, and
     partially constrain, results types of a given entity. A `pdl.type` may
     partially constrain the result by specifying a constant `Type`.
@@ -498,23 +644,44 @@ def PDL_TypeOp : PDL_Op<"type"> {
 
     ```mlir
     // Define a type:
-    %attr = pdl.type
+    %type = pdl.type
 
     // Define a type with a constant value:
-    %attr = pdl.type : i32
+    %type = pdl.type : i32
     ```
   }];
 
   let arguments = (ins OptionalAttr<TypeAttr>:$type);
   let results = (outs PDL_Type:$result);
   let assemblyFormat = "attr-dict (`:` $type^)?";
+}
 
-  let builders = [
-    OpBuilder<(ins CArg<"Type", "Type()">:$ty), [{
-      build($_builder, $_state, $_builder.getType<AttributeType>(),
-            ty ? TypeAttr::get(ty) : TypeAttr());
-    }]>,
-  ];
+//===----------------------------------------------------------------------===//
+// pdl::TypesOp
+//===----------------------------------------------------------------------===//
+
+def PDL_TypesOp : PDL_Op<"types"> {
+  let summary = "Define a range of type handles within a pattern";
+  let description = [{
+    `pdl.types` operations capture result type constraints of `Value`s and
+    `Operation`s. Instances of this operation define the result types of a given
+    entity. A `pdl.types` may partially constrain the results by specifying
+    an array of `Type`s.
+
+    Example:
+
+    ```mlir
+    // Define a range of types:
+    %types = pdl.types
+
+    // Define a range of types with a range of constant values:
+    %types = pdl.types : [i32, i64, i32]
+    ```
+  }];
+
+  let arguments = (ins OptionalAttr<TypeArrayAttr>:$types);
+  let results = (outs PDL_RangeOf<PDL_Type>:$result);
+  let assemblyFormat = "attr-dict (`:` $types^)?";
 }
 
 #endif // MLIR_DIALECT_PDL_IR_PDLOPS

diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
index c854616fbc8f..1e0578339ad8 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
@@ -101,4 +101,18 @@ def PDL_AnyType : Type<
   CPred<"$_self.isa<::mlir::pdl::PDLType>()">, "pdl type",
         "::mlir::pdl::PDLType">;
 
+// A range of positional values of one of the provided types.
+class PDL_RangeOf<Type positionalType> :
+  ContainerType<AnyTypeOf<[positionalType]>, PDL_Range.predicate,
+                "$_self.cast<::mlir::pdl::RangeType>().getElementType()",
+                "range", "::mlir::pdl::RangeType">,
+    BuildableType<"::mlir::pdl::RangeType::get(" # positionalType.builderCall #
+                  ")">;
+
+// Either a positional value or a range of positional values for a given type.
+class PDL_InstOrRangeOf<Type positionalType> :
+    AnyTypeOf<[positionalType, PDL_RangeOf<positionalType>],
+              "single element or range of " # positionalType.summary,
+              "::mlir::pdl::PDLType">;
+
 #endif // MLIR_DIALECT_PDL_IR_PDLTYPES

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index dc1f501825bd..8164c89dac54 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -35,13 +35,19 @@ void PDLDialect::initialize() {
 /// Returns true if the given operation is used by a "binding" pdl operation
 /// within the main matcher body of a `pdl.pattern`.
 static bool hasBindingUseInMatcher(Operation *op, Block *matcherBlock) {
-  for (Operation *user : op->getUsers()) {
+  for (OpOperand &use : op->getUses()) {
+    Operation *user = use.getOwner();
     if (user->getBlock() != matcherBlock)
       continue;
-    if (isa<AttributeOp, OperandOp, OperationOp, RewriteOp>(user))
+    if (isa<AttributeOp, OperandOp, OperandsOp, OperationOp>(user))
+      return true;
+    // Only the first operand of RewriteOp may be bound to, i.e. the root
+    // operation of the pattern.
+    if (isa<RewriteOp>(user) && use.getOperandNumber() == 0)
       return true;
     // A result by itself is not binding, it must also be bound.
-    if (isa<ResultOp>(user) && hasBindingUseInMatcher(user, matcherBlock))
+    if (isa<ResultOp, ResultsOp>(user) &&
+        hasBindingUseInMatcher(user, matcherBlock))
       return true;
   }
   return false;
@@ -107,6 +113,14 @@ static LogicalResult verify(OperandOp op) {
   return verifyHasBindingUseInMatcher(op);
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::OperandsOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(OperandsOp op) {
+  return verifyHasBindingUseInMatcher(op);
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::OperationOp
 //===----------------------------------------------------------------------===//
@@ -177,18 +191,18 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
     if (isa<ApplyNativeRewriteOp>(resultTypeOp))
       continue;
 
-    // If the type is already constrained, there is nothing to do.
-    TypeOp typeOp = cast<TypeOp>(resultTypeOp);
-    if (typeOp.type())
-      continue;
-
     // If the type operation was defined in the matcher and constrains the
     // result of an input operation, it can be used.
     auto constrainsInputOp = [rewriterBlock](Operation *user) {
       return user->getBlock() != rewriterBlock && isa<OperationOp>(user);
     };
-    if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp))
-      continue;
+    if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
+      if (typeOp.type() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
+        continue;
+    } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
+      if (typeOp.types() || llvm::any_of(typeOp->getUsers(), constrainsInputOp))
+        continue;
+    }
 
     return op
         .emitOpError("must have inferable or constrained result types when "
@@ -296,6 +310,36 @@ static LogicalResult verify(ReplaceOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::ResultsOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
+                                         Type &resultType) {
+  if (!index) {
+    resultType = RangeType::get(p.getBuilder().getType<ValueType>());
+    return success();
+  }
+  if (p.parseArrow() || p.parseType(resultType))
+    return failure();
+  return success();
+}
+
+static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
+                                  IntegerAttr index, Type resultType) {
+  if (index)
+    p << " -> " << resultType;
+}
+
+static LogicalResult verify(ResultsOp op) {
+  if (!op.index() && op.getType().isa<pdl::ValueType>()) {
+    return op.emitOpError() << "expected `pdl.range<value>` result type when "
+                               "no index is specified, but got: "
+                            << op.getType();
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::RewriteOp
 //===----------------------------------------------------------------------===//
@@ -340,6 +384,14 @@ static LogicalResult verify(TypeOp op) {
       op, "`pdl.attribute`, `pdl.operand`, or `pdl.operation`");
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::TypesOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(TypesOp op) {
+  return verifyHasBindingUseInMatcher(op, "`pdl.operands`, or `pdl.operation`");
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index a42b51604945..0792f76cba7a 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -27,7 +27,7 @@ module @simple {
   // CHECK:     pdl_interp.apply_rewrite "rewriter"(%[[REWRITE_ROOT]]
   // CHECK:     pdl_interp.finalize
   pdl.pattern : benefit(1) {
-    %root = pdl.operation "foo.op"()
+    %root = pdl.operation "foo.op"
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -69,7 +69,7 @@ module @constraints {
   pdl.pattern : benefit(1) {
     %input0 = pdl.operand
     %input1 = pdl.operand
-    %root = pdl.operation(%input0, %input1)
+    %root = pdl.operation(%input0, %input1 : !pdl.value, !pdl.value)
     %result0 = pdl.result 0 of %root
 
     pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
@@ -96,7 +96,7 @@ module @inputs {
   pdl.pattern : benefit(1) {
     %type = pdl.type : i64
     %input = pdl.operand : %type
-    %root = pdl.operation(%input, %input)
+    %root = pdl.operation(%input, %input : !pdl.value, !pdl.value)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -120,7 +120,7 @@ module @results {
   pdl.pattern : benefit(1) {
     %type1 = pdl.type : i32
     %type2 = pdl.type
-    %root = pdl.operation -> %type1, %type2
+    %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -149,11 +149,11 @@ module @results_as_operands {
   pdl.pattern : benefit(1) {
     %type1 = pdl.type : i32
     %type2 = pdl.type
-    %inputOp = pdl.operation -> %type1, %type2
+    %inputOp = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
     %result1 = pdl.result 0 of %inputOp
     %result2 = pdl.result 1 of %inputOp
 
-    %root = pdl.operation(%result1, %result2)
+    %root = pdl.operation(%result1, %result2 : !pdl.value, !pdl.value)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -168,12 +168,12 @@ module @switch_result_types {
   // CHECK:   pdl_interp.switch_type %[[RESULT_TYPE]] to [i32, i64]
   pdl.pattern : benefit(1) {
     %type = pdl.type : i32
-    %root = pdl.operation -> %type
+    %root = pdl.operation -> (%type : !pdl.type)
     pdl.rewrite %root with "rewriter"
   }
   pdl.pattern : benefit(1) {
     %type = pdl.type : i64
-    %root = pdl.operation -> %type
+    %root = pdl.operation -> (%type : !pdl.type)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -195,13 +195,13 @@ module @predicate_ordering  {
   pdl.pattern : benefit(1) {
     %resultType = pdl.type
     pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type)
-    %root = pdl.operation -> %resultType
+    %root = pdl.operation -> (%resultType : !pdl.type)
     pdl.rewrite %root with "rewriter"
   }
 
   pdl.pattern : benefit(1) {
     %resultType = pdl.type
-    %apply = pdl.operation -> %resultType
+    %apply = pdl.operation -> (%resultType : !pdl.type)
     pdl.rewrite %apply with "rewriter"
   }
 }

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index 3d0d565c547f..67ac7c811ab7 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -9,7 +9,7 @@ module @external {
   // CHECK:     pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
   pdl.pattern : benefit(1) {
     %input = pdl.operand
-    %root = pdl.operation "foo.op"(%input)
+    %root = pdl.operation "foo.op"(%input : !pdl.value)
     pdl.rewrite %root with "rewriter"[true](%input : !pdl.value)
   }
 }
@@ -60,12 +60,12 @@ module @operation_operands {
   // CHECK:     pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
   pdl.pattern : benefit(1) {
     %operand = pdl.operand
-    %root = pdl.operation "foo.op"(%operand)
+    %root = pdl.operation "foo.op"(%operand : !pdl.value)
     pdl.rewrite %root {
       %type = pdl.type : i32
-      %newOp = pdl.operation "foo.op"(%operand) -> %type
+      %newOp = pdl.operation "foo.op"(%operand : !pdl.value) -> (%type : !pdl.type)
       %result = pdl.result 0 of %newOp
-      %newOp1 = pdl.operation "foo.op2"(%result)
+      %newOp1 = pdl.operation "foo.op2"(%result : !pdl.value)
       pdl.erase %root
     }
   }
@@ -82,12 +82,12 @@ module @operation_operands {
   // CHECK:     pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
   pdl.pattern : benefit(1) {
     %operand = pdl.operand
-    %root = pdl.operation "foo.op"(%operand)
+    %root = pdl.operation "foo.op"(%operand : !pdl.value)
     pdl.rewrite %root {
       %type = pdl.type : i32
-      %newOp = pdl.operation "foo.op"(%operand) -> %type
+      %newOp = pdl.operation "foo.op"(%operand : !pdl.value) -> (%type : !pdl.type)
       %result = pdl.result 0 of %newOp
-      %newOp1 = pdl.operation "foo.op2"(%result)
+      %newOp1 = pdl.operation "foo.op2"(%result : !pdl.value)
       pdl.erase %root
     }
   }
@@ -103,10 +103,10 @@ module @operation_result_types {
   pdl.pattern : benefit(1) {
     %rootType = pdl.type
     %rootType1 = pdl.type
-    %root = pdl.operation "foo.op" -> %rootType, %rootType1
+    %root = pdl.operation "foo.op" -> (%rootType, %rootType1 : !pdl.type, !pdl.type)
     pdl.rewrite %root {
       %newType1 = pdl.type
-      %newOp = pdl.operation "foo.op" -> %rootType, %newType1
+      %newOp = pdl.operation "foo.op" -> (%rootType, %newType1 : !pdl.type, !pdl.type)
       pdl.replace %root with %newOp
     }
   }
@@ -123,9 +123,9 @@ module @replace_with_op {
   // CHECK:     pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
   pdl.pattern : benefit(1) {
     %type = pdl.type : i32
-    %root = pdl.operation "foo.op" -> %type
+    %root = pdl.operation "foo.op" -> (%type : !pdl.type)
     pdl.rewrite %root {
-      %newOp = pdl.operation "foo.op" -> %type
+      %newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
       pdl.replace %root with %newOp
     }
   }
@@ -142,11 +142,11 @@ module @replace_with_values {
   // CHECK:     pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
   pdl.pattern : benefit(1) {
     %type = pdl.type : i32
-    %root = pdl.operation "foo.op" -> %type
+    %root = pdl.operation "foo.op" -> (%type : !pdl.type)
     pdl.rewrite %root {
-      %newOp = pdl.operation "foo.op" -> %type
+      %newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
       %newResult = pdl.result 0 of %newOp
-      pdl.replace %root with (%newResult)
+      pdl.replace %root with (%newResult : !pdl.value)
     }
   }
 }
@@ -178,10 +178,10 @@ module @apply_native_rewrite {
   // CHECK:     pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
   pdl.pattern : benefit(1) {
     %type = pdl.type
-    %root = pdl.operation "foo.op" -> %type
+    %root = pdl.operation "foo.op" -> (%type : !pdl.type)
     pdl.rewrite %root {
       %newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
-      %newOp = pdl.operation "foo.op" -> %newType
+      %newOp = pdl.operation "foo.op" -> (%newType : !pdl.type)
       pdl.replace %root with %newOp
     }
   }

diff  --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index a054da24ba4d..e371d8408670 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -38,7 +38,7 @@ pdl.pattern : benefit(1) {
  // expected-error@below {{expected only one of [`type`, `value`] to be set}}
   %attr = pdl.attribute : %type 10
 
-  %op = pdl.operation "foo.op" {"attr" = %attr} -> %type
+  %op = pdl.operation "foo.op" {"attr" = %attr} -> (%type : !pdl.type)
   pdl.rewrite %op with "rewriter"
 }
 
@@ -90,6 +90,20 @@ pdl.pattern : benefit(1) {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// pdl::OperandsOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+  // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
+  %unused = pdl.operands
+
+  %op = pdl.operation "foo.op"
+  pdl.rewrite %op with "rewriter"
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl::OperationOp
 //===----------------------------------------------------------------------===//
@@ -116,13 +130,13 @@ pdl.pattern : benefit(1) {
 // -----
 
 pdl.pattern : benefit(1) {
-  %op = pdl.operation "foo.op"()
+  %op = pdl.operation "foo.op"
   pdl.rewrite %op {
     %type = pdl.type
 
     // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
     // expected-note@below {{result type #0 was not constrained}}
-    %newOp = pdl.operation "foo.op" -> %type
+    %newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
   }
 }
 
@@ -163,9 +177,9 @@ pdl.pattern : benefit(1) {
 
 pdl.pattern : benefit(1) {
   %type = pdl.type : i32
-  %root = pdl.operation "foo.op" -> %type
+  %root = pdl.operation "foo.op" -> (%type : !pdl.type)
   pdl.rewrite %root {
-    %newOp = pdl.operation "foo.op" -> %type
+    %newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
     %newResult = pdl.result 0 of %newOp
 
     // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
@@ -177,6 +191,19 @@ pdl.pattern : benefit(1) {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// pdl::ResultsOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+  %root = pdl.operation "foo.op"
+  // expected-error@below {{expected `pdl.range<value>` result type when no index is specified, but got: '!pdl.value'}}
+  %results = "pdl.results"(%root) : (!pdl.operation) -> !pdl.value
+  pdl.rewrite %root with "rewriter"
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl::RewriteOp
 //===----------------------------------------------------------------------===//
@@ -237,3 +264,17 @@ pdl.pattern : benefit(1) {
   %op = pdl.operation "foo.op"
   pdl.rewrite %op with "rewriter"
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// pdl::TypesOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+  // expected-error@below {{expected a bindable (i.e. `pdl.operands`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}}
+  %unused = pdl.types
+
+  %op = pdl.operation "foo.op"
+  pdl.rewrite %op with "rewriter"
+}

diff  --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index d376f001fcfa..07e98f9e5868 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -8,12 +8,12 @@ pdl.pattern @operations : benefit(1) {
   // Operation with attributes and results.
   %attribute = pdl.attribute
   %type = pdl.type
-  %op0 = pdl.operation {"attr" = %attribute} -> %type
+  %op0 = pdl.operation {"attr" = %attribute} -> (%type : !pdl.type)
   %op0_result = pdl.result 0 of %op0
 
   // Operation with input.
   %input = pdl.operand
-  %root = pdl.operation(%op0_result, %input)
+  %root = pdl.operation(%op0_result, %input : !pdl.value, !pdl.value)
   pdl.rewrite %root with "rewriter"
 }
 
@@ -21,7 +21,7 @@ pdl.pattern @operations : benefit(1) {
 
 pdl.pattern @rewrite_with_args : benefit(1) {
   %input = pdl.operand
-  %root = pdl.operation(%input)
+  %root = pdl.operation(%input : !pdl.value)
   pdl.rewrite %root with "rewriter"(%input : !pdl.value)
 }
 
@@ -36,7 +36,7 @@ pdl.pattern @rewrite_with_params : benefit(1) {
 
 pdl.pattern @rewrite_with_args_and_params : benefit(1) {
   %input = pdl.operand
-  %root = pdl.operation(%input)
+  %root = pdl.operation(%input : !pdl.value)
   pdl.rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
 }
 
@@ -47,10 +47,10 @@ pdl.pattern @rewrite_with_args_and_params : benefit(1) {
 pdl.pattern @infer_type_from_operation_replace : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
-  %root = pdl.operation -> %type1, %type2
+  %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
   pdl.rewrite %root {
     %type3 = pdl.type
-    %newOp = pdl.operation "foo.op" -> %type1, %type3
+    %newOp = pdl.operation "foo.op" -> (%type1, %type3 : !pdl.type, !pdl.type)
     pdl.replace %root with %newOp
   }
 }
@@ -58,12 +58,25 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) {
 // -----
 
 // Check that the result type of an operation within a rewrite can be inferred
-// from a pdl.replace.
+// from types used within the match block.
 pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
   %type1 = pdl.type : i32
   %type2 = pdl.type
-  %root = pdl.operation -> %type1, %type2
+  %root = pdl.operation -> (%type1, %type2 : !pdl.type, !pdl.type)
+  pdl.rewrite %root {
+    %newOp = pdl.operation "foo.op" -> (%type1, %type2 : !pdl.type, !pdl.type)
+  }
+}
+
+// -----
+
+// Check that the result types of an operation within a rewrite can be inferred
+// from type ranges used within the match block.
+pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
+  %types = pdl.types
+  %root = pdl.operation -> (%types : !pdl.range<type>)
   pdl.rewrite %root {
-    %newOp = pdl.operation "foo.op" -> %type1, %type2
+    %otherTypes = pdl.types : [i32, i64]
+    %newOp = pdl.operation "foo.op" -> (%types, %otherTypes : !pdl.range<type>, !pdl.range<type>)
   }
 }


        


More information about the Mlir-commits mailing list