[Mlir-commits] [mlir] 3c75228 - [mlir:PDLInterp] Refactor the implementation of result type inference

River Riddle llvmlistbot at llvm.org
Sun May 1 12:46:54 PDT 2022


Author: River Riddle
Date: 2022-05-01T12:25:05-07:00
New Revision: 3c752289912895e067eb173485cadce6c618d6d4

URL: https://github.com/llvm/llvm-project/commit/3c752289912895e067eb173485cadce6c618d6d4
DIFF: https://github.com/llvm/llvm-project/commit/3c752289912895e067eb173485cadce6c618d6d4.diff

LOG: [mlir:PDLInterp] Refactor the implementation of result type inference

The current implementation uses a discrete "pdl_interp.inferred_types"
operation, which acts as a "fake" handle to a type range. This op is
used as a signal to pdl_interp.create_operation that types should be
inferred. This approach is awkward and clunky for several reasons:

* This op doesn't have a bytecode representation, and its conversion
  to bytecode implicitly assumes that it is only used in a certain way.
  The current lowering is also broken and seemingly untested.

* Because this is a separate operation, it gives the impression
  that it can be used multiple times, or that after the first use
  the value contains the inferred types. This isn't the case, though:
  the resultant type range can never actually be used as a type range.

This commit refactors the representation by removing the discrete
InferredTypesOp and instead adding a UnitAttr to
pdl_interp.create_operation that signals when the created operation
should infer its result types. This leads to a much cleaner abstraction
and a more optimal bytecode lowering, and it also allows for better
error handling and diagnostics when a created operation doesn't actually
support type inference.

Differential Revision: https://reviews.llvm.org/D124587

Added: 
    mlir/test/Dialect/PDLInterp/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
    mlir/test/Dialect/PDLInterp/ops.mlir
    mlir/test/Rewrite/pdl-bytecode.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index fbc73c0708723..44dc6cdf12219 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -409,14 +409,18 @@ def PDLInterp_CreateOperationOp
   let description = [{
     `pdl_interp.create_operation` operations create an `Operation` instance with
     the specified attributes, operands, and result types. See `pdl.operation`
-    for a more detailed description on the interpretation of the arguments to
-    this operation.
+    for a more detailed description on the general interpretation of the arguments
+    to this operation.
 
     Example:
 
     ```mlir
     // Create an instance of a `foo.op` operation.
     %op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> (%type : !pdl.type)
+    
+    // Create an instance of a `foo.op` operation that has inferred result types
+    // (using the InferTypeOpInterface).
+    %op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> <inferred>
     ```
   }];
 
@@ -424,22 +428,26 @@ def PDLInterp_CreateOperationOp
                        Variadic<PDL_InstOrRangeOf<PDL_Value>>:$inputOperands,
                        Variadic<PDL_Attribute>:$inputAttributes,
                        StrArrayAttr:$inputAttributeNames,
-                       Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes);
+                       Variadic<PDL_InstOrRangeOf<PDL_Type>>:$inputResultTypes,
+                       UnitAttr:$inferredResultTypes);
   let results = (outs PDL_Operation:$resultOp);
 
   let builders = [
     OpBuilder<(ins "StringRef":$name, "ValueRange":$types,
-      "ValueRange":$operands, "ValueRange":$attributes,
-      "ArrayAttr":$attributeNames), [{
+      "bool":$inferredResultTypes, "ValueRange":$operands,
+      "ValueRange":$attributes, "ArrayAttr":$attributeNames), [{
       build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
-            operands, attributes, attributeNames, types);
+            operands, attributes, attributeNames, types, inferredResultTypes);
     }]>
   ];
   let assemblyFormat = [{
-    $name (`(` $inputOperands^ `:` type($inputOperands) `)`)?
+    $name (`(` $inputOperands^ `:` type($inputOperands) `)`)? ``
     custom<CreateOperationOpAttributes>($inputAttributes, $inputAttributeNames)
-    (`->` `(` $inputResultTypes^ `:` type($inputResultTypes) `)`)? attr-dict
+    custom<CreateOperationOpResults>($inputResultTypes, type($inputResultTypes),
+                                     $inferredResultTypes)
+    attr-dict
   }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -961,33 +969,6 @@ def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect,
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// pdl_interp::InferredTypesOp
-//===----------------------------------------------------------------------===//
-
-def PDLInterp_InferredTypesOp : PDLInterp_Op<"inferred_types"> {
-  let summary = "Generate a handle to a range of Types that are \"inferred\"";
-  let description = [{
-    `pdl_interp.inferred_types` operations generate handles to ranges of types
-    that should be inferred. This signals to other operations, such as
-    `pdl_interp.create_operation`, that these types should be inferred.
-
-    Example:
-
-    ```mlir
-    %types = pdl_interp.inferred_types
-    ```
-  }];
-  let results = (outs PDL_RangeOf<PDL_Type>:$result);
-  let assemblyFormat = "attr-dict";
-  let builders = [
-    OpBuilder<(ins), [{
-      build($_builder, $_state,
-            pdl::RangeType::get($_builder.getType<pdl::TypeType>()));
-    }]>
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // pdl_interp::IsNotNullOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 6057a21939461..c24620a6729b9 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -100,11 +100,12 @@ struct PatternLowering {
                         function_ref<Value(Value)> mapRewriteValue);
 
   /// Generate the values used for resolving the result types of an operation
-  /// created within a dag rewriter region.
+  /// created within a dag rewriter region. If the result types of the operation
+  /// should be inferred, `hasInferredResultTypes` is set to true.
   void generateOperationResultTypeRewriter(
-      pdl::OperationOp op, SmallVectorImpl<Value> &types,
-      DenseMap<Value, Value> &rewriteValues,
-      function_ref<Value(Value)> mapRewriteValue);
+      pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
+      SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
+      bool &hasInferredResultTypes);
 
   /// A builder to use when generating interpreter operations.
   OpBuilder builder;
@@ -707,15 +708,16 @@ void PatternLowering::generateRewriter(
   for (Value attr : operationOp.attributes())
     attributes.push_back(mapRewriteValue(attr));
 
+  bool hasInferredResultTypes = false;
   SmallVector<Value, 2> types;
-  generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
-                                      mapRewriteValue);
+  generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
+                                      rewriteValues, hasInferredResultTypes);
 
   // Create the new operation.
   Location loc = operationOp.getLoc();
   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
-      loc, *operationOp.name(), types, operands, attributes,
-      operationOp.attributeNames());
+      loc, *operationOp.name(), types, hasInferredResultTypes, operands,
+      attributes, operationOp.attributeNames());
   rewriteValues[operationOp.op()] = createdOp;
 
   // Generate accesses for any results that have their types constrained.
@@ -825,9 +827,9 @@ void PatternLowering::generateRewriter(
 }
 
 void PatternLowering::generateOperationResultTypeRewriter(
-    pdl::OperationOp op, SmallVectorImpl<Value> &types,
-    DenseMap<Value, Value> &rewriteValues,
-    function_ref<Value(Value)> mapRewriteValue) {
+    pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
+    SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
+    bool &hasInferredResultTypes) {
   // Look for an operation that was replaced by `op`. The result types will be
   // inferred from the results that were replaced.
   Block *rewriterBlock = op->getBlock();
@@ -851,36 +853,54 @@ void PatternLowering::generateOperationResultTypeRewriter(
     return;
   }
 
-  // Check if the operation has type inference support.
-  if (op.hasTypeInference()) {
-    types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
-    return;
-  }
-
-  // Otherwise, handle inference for each of the result types individually.
+  // Try to handle resolution for each of the result types individually. This is
+  // preferred over type inferrence because it will allow for us to use existing
+  // types directly, as opposed to trying to rebuild the type list.
   OperandRange resultTypeValues = op.types();
-  types.reserve(resultTypeValues.size());
-  for (const auto &it : llvm::enumerate(resultTypeValues)) {
-    Value resultType = it.value();
+  auto tryResolveResultTypes = [&] {
+    types.reserve(resultTypeValues.size());
+    for (const auto &it : llvm::enumerate(resultTypeValues)) {
+      Value resultType = it.value();
+
+      // Check for an already translated value.
+      if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
+        types.push_back(existingRewriteValue);
+        continue;
+      }
 
-    // Check for an already translated value.
-    if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
-      types.push_back(existingRewriteValue);
-      continue;
-    }
+      // Check for an input from the matcher.
+      if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
+        types.push_back(mapRewriteValue(resultType));
+        continue;
+      }
 
-    // Check for an input from the matcher.
-    if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
-      types.push_back(mapRewriteValue(resultType));
-      continue;
+      // Otherwise, we couldn't infer the result types. Bail out here to see if
+      // we can infer the types for this operation from another way.
+      types.clear();
+      return failure();
     }
+    return success();
+  };
+  if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
+    return;
 
-    // The verifier asserts that the result types of each pdl.operation can be
-    // inferred. If we reach here, there is a bug either in the logic above or
-    // in the verifier for pdl.operation.
-    op->emitOpError() << "unable to infer result type for operation";
-    llvm_unreachable("unable to infer result type for operation");
+  // Otherwise, check if the operation has type inference support itself.
+  if (op.hasTypeInference()) {
+    hasInferredResultTypes = true;
+    return;
   }
+
+  // If the types could not be inferred from any context and there weren't any
+  // explicit result types, assume the user actually meant for the operation to
+  // have no results.
+  if (resultTypeValues.empty())
+    return;
+
+  // The verifier asserts that the result types of each pdl.operation can be
+  // inferred. If we reach here, there is a bug either in the logic above or
+  // in the verifier for pdl.operation.
+  op->emitOpError() << "unable to infer result type for operation";
+  llvm_unreachable("unable to infer result type for operation");
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 64a2be6a37778..40e55992cd2d7 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -47,6 +47,23 @@ static LogicalResult verifySwitchOp(OpT op) {
 // pdl_interp::CreateOperationOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult CreateOperationOp::verify() {
+  if (!getInferredResultTypes())
+    return success();
+  if (!getInputResultTypes().empty()) {
+    return emitOpError("with inferred results cannot also have "
+                       "explicit result types");
+  }
+  OperationName opName(getName(), getContext());
+  if (!opName.hasInterface<InferTypeOpInterface>()) {
+    return emitOpError()
+           << "has inferred results, but the created operation '" << opName
+           << "' does not support result type inference (or is not "
+              "registered)";
+  }
+  return success();
+}
+
 static ParseResult parseCreateOperationOpAttributes(
     OpAsmParser &p,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
@@ -82,6 +99,41 @@ static void printCreateOperationOpAttributes(OpAsmPrinter &p,
   p << '}';
 }
 
+static ParseResult parseCreateOperationOpResults(
+    OpAsmParser &p,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
+    SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
+  if (failed(p.parseOptionalArrow()))
+    return success();
+
+  // Handle the case of inferred results.
+  if (succeeded(p.parseOptionalLess())) {
+    if (p.parseKeyword("inferred") || p.parseGreater())
+      return failure();
+    inferredResultTypes = p.getBuilder().getUnitAttr();
+    return success();
+  }
+
+  // Otherwise, parse the explicit results.
+  return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
+                 p.parseColonTypeList(resultTypes) || p.parseRParen());
+}
+
+static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
+                                          OperandRange resultOperands,
+                                          TypeRange resultTypes,
+                                          UnitAttr inferredResultTypes) {
+  // Handle the case of inferred results.
+  if (inferredResultTypes) {
+    p << " -> <inferred>";
+    return;
+  }
+
+  // Otherwise, handle the explicit results.
+  if (!resultTypes.empty())
+    p << " -> (" << resultOperands << " : " << resultTypes << ")";
+}
+
 //===----------------------------------------------------------------------===//
 // pdl_interp::ForEachOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index c2dc41a81c6f8..ad4c078f2e3a5 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -162,6 +162,10 @@ enum OpCode : ByteCodeField {
 };
 } // namespace
 
+/// A marker used to indicate if an operation should infer types.
+static constexpr ByteCodeField kInferTypesMarker =
+    std::numeric_limits<ByteCodeField>::max();
+
 //===----------------------------------------------------------------------===//
 // ByteCode Generation
 //===----------------------------------------------------------------------===//
@@ -273,7 +277,6 @@ class Generator {
   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
-  void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
@@ -723,8 +726,7 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
   LLVM_DEBUG({
     // The following list must contain all the operations that do not
     // produce any bytecode.
-    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
-             pdl_interp::InferredTypesOp>(op))
+    if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
       writer.appendInline(op->getLoc());
   });
   TypeSwitch<Operation *>(op)
@@ -742,11 +744,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
             pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
             pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
             pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
-            pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
-            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
-            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
-            pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
-            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
+            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
+            pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
+            pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
+            pdl_interp::SwitchResultCountOp>(
           [&](auto interpOp) { this->generate(interpOp, writer); })
       .Default([](Operation *) {
         llvm_unreachable("unknown `pdl_interp` operation");
@@ -847,7 +849,13 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
   writer.append(static_cast<ByteCodeField>(attributes.size()));
   for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
     writer.append(std::get<0>(it), std::get<1>(it));
-  writer.appendPDLValueList(op.getInputResultTypes());
+
+  // Add the result types. If the operation has inferred results, we use a
+  // marker "size" value. Otherwise, we add the list of explicit result types.
+  if (op.getInferredResultTypes())
+    writer.append(kInferTypesMarker);
+  else
+    writer.appendPDLValueList(op.getInputResultTypes());
 }
 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
   // Simply repoint the memory index of the result to the constant.
@@ -955,12 +963,6 @@ void Generator::generate(pdl_interp::GetValueTypeOp op,
     writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
   }
 }
-
-void Generator::generate(pdl_interp::InferredTypesOp op,
-                         ByteCodeWriter &writer) {
-  // InferType maps to a null type as a marker for inferring result types.
-  getMemIndex(op.getResult()) = getMemIndex(Type());
-}
 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
   writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
 }
@@ -1526,30 +1528,31 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
       state.addAttribute(name, attr);
   }
 
-  for (unsigned i = 0, e = read(); i != e; ++i) {
-    if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
-      state.types.push_back(read<Type>());
-      continue;
-    }
-
-    // If we find a null range, this signals that the types are infered.
-    if (TypeRange *resultTypes = read<TypeRange *>()) {
-      state.types.append(resultTypes->begin(), resultTypes->end());
-      continue;
-    }
-
-    // Handle the case where the operation has inferred types.
+  // Read in the result types. If the "size" is the sentinel value, this
+  // indicates that the result types should be inferred.
+  unsigned numResults = read();
+  if (numResults == kInferTypesMarker) {
     InferTypeOpInterface::Concept *inferInterface =
         state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
+    assert(inferInterface &&
+           "expected operation to provide InferTypeOpInterface");
 
     // TODO: Handle failure.
-    state.types.clear();
     if (failed(inferInterface->inferReturnTypes(
             state.getContext(), state.location, state.operands,
             state.attributes.getDictionary(state.getContext()), state.regions,
             state.types)))
       return;
-    break;
+  } else {
+    // Otherwise, this is a fixed number of results.
+    for (unsigned i = 0; i != numResults; ++i) {
+      if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+        state.types.push_back(read<Type>());
+      } else {
+        TypeRange *resultTypes = read<TypeRange *>();
+        state.types.append(resultTypes->begin(), resultTypes->end());
+      }
+    }
   }
 
   Operation *resultOp = rewriter.create(state);

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index 9f3141aade1d8..8d1859692a305 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -127,6 +127,29 @@ module @operation_infer_types_from_otherop_results {
 
 // -----
 
+// CHECK-LABEL: module @operation_infer_types_from_interface
+module @operation_infer_types_from_interface {
+  // Unused operation that ensures the arithmetic dialect is loaded for use in the pattern.
+  arith.constant true
+
+  // CHECK: module @rewriters
+  // CHECK:   func @pdl_generated_rewriter
+  // CHECK:     %[[CST:.*]] = pdl_interp.create_operation "arith.constant" -> <inferred>
+  // CHECK:     %[[CST_RES:.*]] = pdl_interp.get_results of %[[CST]] : !pdl.range<value>
+  // CHECK:     %[[CST_TYPE:.*]] = pdl_interp.get_value_type of %[[CST_RES]] : !pdl.range<type>
+  // CHECK:     pdl_interp.create_operation "foo.op"  -> (%[[CST_TYPE]] : !pdl.range<type>)
+  pdl.pattern : benefit(1) {
+    %root = operation "foo.op"
+    rewrite %root {
+      %types = types
+      %newOp = operation "arith.constant" -> (%types : !pdl.range<type>)
+      %newOp2 = operation "foo.op" -> (%types : !pdl.range<type>)
+    }
+  }
+}
+
+// -----
+
 // CHECK-LABEL: module @replace_with_op
 module @replace_with_op {
   // CHECK: module @rewriters

diff  --git a/mlir/test/Dialect/PDLInterp/invalid.mlir b/mlir/test/Dialect/PDLInterp/invalid.mlir
new file mode 100644
index 0000000000000..e44625d91fc8f
--- /dev/null
+++ b/mlir/test/Dialect/PDLInterp/invalid.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// pdl::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+pdl_interp.func @rewriter() {
+  // expected-error at +1 {{op has inferred results, but the created operation 'foo.op' does not support result type inference}}
+  %op = pdl_interp.create_operation "foo.op" -> <inferred>
+  pdl_interp.finalize
+}
+
+// -----
+
+pdl_interp.func @rewriter() {
+  %type = pdl_interp.create_type i32
+  // expected-error at +1 {{op with inferred results cannot also have explicit result types}}
+  %op = "pdl_interp.create_operation"(%type) {
+    inferredResultTypes,
+    inputAttributeNames = [],
+    name = "foo.op",
+    operand_segment_sizes = dense<[0, 0, 1]> : vector<3xi32>
+  } : (!pdl.type) -> (!pdl.operation)
+  pdl_interp.finalize
+}
+

diff  --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir
index 52b711a9419ae..ef9cefe813a5b 100644
--- a/mlir/test/Dialect/PDLInterp/ops.mlir
+++ b/mlir/test/Dialect/PDLInterp/ops.mlir
@@ -6,6 +6,10 @@
 
 // -----
 
+// Unused operation to force loading the `arithmetic` dialect for the
+// test of type inferrence.
+arith.constant true
+
 func.func @operations(%attribute: !pdl.attribute,
                  %input: !pdl.value,
                  %type: !pdl.type) {
@@ -21,6 +25,9 @@ func.func @operations(%attribute: !pdl.attribute,
   // operands, and results
   %op3 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) -> (%type : !pdl.type)
 
+  // inferred results
+  %op4 = pdl_interp.create_operation "arith.constant" -> <inferred>
+
   pdl_interp.finalize
 }
 

diff  --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index e1a8c6081d4e5..aed1bbc955524 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -531,6 +531,41 @@ module @ir attributes { test.check_types_1 } {
 // pdl_interp::CreateOperationOp
 //===----------------------------------------------------------------------===//
 
+// Unused operation to force loading the `arithmetic` dialect for the
+// test of type inferrence.
+arith.constant 10
+
+// Test support for inferring the types of an operation.
+module @patterns {
+  pdl_interp.func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    pdl_interp.func @success(%root : !pdl.operation) {
+      %attr = pdl_interp.create_attribute true
+      %cst = pdl_interp.create_operation "arith.constant" {"value" = %attr} -> <inferred>
+      %cstResults = pdl_interp.get_results of %cst : !pdl.range<value>
+      %op = pdl_interp.create_operation "test.success"(%cstResults : !pdl.range<value>)
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.create_op_infer_results
+// CHECK: %[[CST:.*]] = arith.constant true
+// CHECK: "test.success"(%[[CST]])
+module @ir attributes { test.create_op_infer_results } {
+  %results:2 = "test.op"() : () -> (i64, i64)
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -1181,12 +1216,6 @@ module @ir attributes { test.get_results_2 } {
 
 // Fully tested within the tests for other operations.
 
-//===----------------------------------------------------------------------===//
-// pdl_interp::InferredTypesOp
-//===----------------------------------------------------------------------===//
-
-// Fully tested within the tests for other operations.
-
 //===----------------------------------------------------------------------===//
 // pdl_interp::IsNotNullOp
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list